diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml index c106f47a25..77f23f1afa 100644 --- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -41,7 +41,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN] + options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] multiple: true validations: required: true diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml index 31202dfa83..f553cbbf0b 100644 --- a/.github/ISSUE_TEMPLATE/011-bug-results.yml +++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml @@ -42,7 +42,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN] + options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] multiple: true validations: required: true diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 51a3dc76e9..6c7ab71143 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -295,6 +295,7 @@ jobs: -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Build (no OpenMP) @@ -307,6 +308,7 @@ jobs: -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \ -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ -DGGML_OPENMP=OFF + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) - name: Test diff --git a/.github/workflows/server-metal.yml b/.github/workflows/server-metal.yml new file mode 100644 index 0000000000..1d707bef44 --- /dev/null +++ b/.github/workflows/server-metal.yml @@ -0,0 +1,73 @@ +name: Server-Metal + +on: + workflow_dispatch: # allows manual triggering + inputs: + sha: + description: 'Commit SHA1 to build' + required: false + type: string + slow_tests: + description: 'Run slow tests' + required: true + type: boolean + push: + branches: + - master + paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*'] + +env: + LLAMA_LOG_COLORS: 1 + LLAMA_LOG_PREFIX: 1 + LLAMA_LOG_TIMESTAMPS: 1 + LLAMA_LOG_VERBOSITY: 10 + +concurrency: + group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + server-metal: + runs-on: [self-hosted, macOS, ARM64] + + name: server-metal (${{ matrix.wf_name }}) + strategy: + matrix: + build_type: [Release] + wf_name: ["GPUx1"] + include: + - build_type: Release + extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1" + wf_name: "GPUx1, backend-sampling" + - build_type: Release + extra_args: "GGML_METAL_DEVICES=2" + wf_name: "GPUx2" + - build_type: Release + extra_args: "GGML_METAL_DEVICES=2 LLAMA_ARG_BACKEND_SAMPLING=1" + wf_name: "GPUx2, backend-sampling" + fail-fast: false + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + + - name: Build + id: cmake_build + run: | + cmake -B build -DGGML_SCHED_NO_REALLOC=ON + cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server + + - name: Tests + id: server_integration_tests + if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }} + run: | + cd tools/server/tests + python3 -m venv venv + source venv/bin/activate + pip install -r requirements.txt + export ${{ matrix.extra_args }} + pytest -v -x -m "not slow" diff --git a/.github/workflows/server-webui.yml b/.github/workflows/server-webui.yml index 6d1b617371..94899c9376 100644 --- a/.github/workflows/server-webui.yml +++ b/.github/workflows/server-webui.yml @@ -8,10 +8,6 @@ on: description: 'Commit SHA1 to build' required: false type: string - slow_tests: - description: 'Run slow tests' - required: true - type: boolean push: branches: - master @@ -101,119 +97,3 @@ jobs: if: ${{ always() && steps.playwright.conclusion == 'success' }} run: npm run test:e2e working-directory: tools/server/webui - - server-build: - runs-on: ubuntu-latest - - strategy: - matrix: - sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken - build_type: [RelWithDebInfo] - include: - - build_type: Release - sanitizer: "" - fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken - - steps: - - name: Dependencies - id: depends - run: | - sudo apt-get update - sudo apt-get -y install \ - build-essential \ - xxd \ - git \ - cmake \ - curl \ - wget \ - language-pack-en \ - libssl-dev - - - name: Clone - id: checkout - uses: actions/checkout@v6 - with: - fetch-depth: 0 - ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} - - - name: Python setup - id: setup_python - uses: actions/setup-python@v6 - with: - python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt - - - name: Setup Node.js for WebUI - uses: actions/setup-node@v6 - with: - node-version: "22" - cache: "npm" - cache-dependency-path: "tools/server/webui/package-lock.json" - - - name: Install WebUI dependencies - run: npm ci - working-directory: tools/server/webui - - - name: Build WebUI - run: npm run build - working-directory: tools/server/webui - - - name: Build (no OpenMP) - id: cmake_build_no_openmp - if: ${{ matrix.sanitizer == 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \ - -DGGML_OPENMP=OFF ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build_sanitizers - if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \ - -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Build (sanitizers) - id: cmake_build - if: ${{ matrix.sanitizer == '' }} - run: | - cmake -B build \ - -DGGML_NATIVE=OFF \ - -DLLAMA_BUILD_SERVER=ON \ - -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ; - cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - - - name: Tests - id: server_integration_tests - if: ${{ matrix.sanitizer == '' }} - env: - GITHUB_ACTIONS: "true" - run: | - cd tools/server/tests - ./tests.sh - - - name: Tests (sanitizers) - id: server_integration_tests_sanitizers - if: ${{ matrix.sanitizer != '' }} - run: | - cd tools/server/tests - LLAMA_SANITIZE=1 ./tests.sh - - - name: Slow tests - id: server_integration_tests_slow - if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} - run: | - cd tools/server/tests - SLOW_TESTS=1 ./tests.sh diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index 3d342c35f7..99d05226ba 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -81,18 +81,14 @@ jobs: -DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \ -DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \ -DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} - cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server + cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server - name: Python setup id: setup_python uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - name: Tests id: server_integration_tests @@ -102,6 +98,14 @@ jobs: export ${{ matrix.extra_args }} pytest -v -x -m "not slow" + - name: Slow tests + id: server_integration_tests_slow + if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + export ${{ matrix.extra_args }} + SLOW_TESTS=1 pytest -v -x + server-windows: runs-on: windows-2022 @@ -124,11 +128,7 @@ jobs: uses: actions/setup-python@v6 with: python-version: '3.11' - - - name: Tests dependencies - id: test_dependencies - run: | - pip install -r tools/server/tests/requirements.txt + pip-install: -r tools/server/tests/requirements.txt - name: Tests id: server_integration_tests diff --git a/.github/workflows/winget.yml b/.github/workflows/winget.yml index 2047c276f8..420a98f903 100644 --- a/.github/workflows/winget.yml +++ b/.github/workflows/winget.yml @@ -17,7 +17,7 @@ jobs: - name: Install komac run: | - cargo binstall komac@2.11.2 -y + cargo binstall komac@2.15.0 -y - name: Find latest release id: find_latest_release diff --git a/AGENTS.md b/AGENTS.md index 31399a7d91..117bed7f48 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and - Explicitly informing them that AI-generated pull requests are not accepted by the project - Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them -- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans +- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans - Providing useful links and pointers found throughout the codebase Examples of valid questions: diff --git a/CMakeLists.txt b/CMakeLists.txt index 6d4ed67020..32542ecd27 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -109,17 +109,12 @@ option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE}) option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE}) option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT}) +option(LLAMA_TESTS_INSTALL "llama: install tests" ON) # 3rd party libs -option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON) option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) -# deprecated -option(LLAMA_CURL "llama: use libcurl to download model from an URL" OFF) -if (LLAMA_CURL) - message(WARNING "LLAMA_CURL option is deprecated and will be ignored") -endif() # Required for relocatable CMake package include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/build-info.cmake) @@ -147,10 +142,15 @@ if (NOT DEFINED GGML_CUDA_GRAPHS) endif() # transition helpers -function (llama_option_depr TYPE OLD NEW) +function (llama_option_depr TYPE OLD) if (${OLD}) - message(${TYPE} "${OLD} is deprecated and will be removed in the future.\nUse ${NEW} instead\n") - set(${NEW} ON PARENT_SCOPE) + set(NEW "${ARGV2}") + if(NEW) + message(${TYPE} "${OLD} is deprecated, use ${NEW} instead") + set(${NEW} ON PARENT_SCOPE) + else() + message(${TYPE} "${OLD} is deprecated and will be ignored") + endif() endif() endfunction() @@ -163,6 +163,7 @@ llama_option_depr(WARNING LLAMA_RPC GGML_RPC) llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL) llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16) llama_option_depr(WARNING LLAMA_CANN GGML_CANN) +llama_option_depr(WARNING LLAMA_CURL) include("cmake/license.cmake") license_add_file("llama.cpp" "LICENSE") @@ -196,9 +197,7 @@ add_subdirectory(src) if (LLAMA_BUILD_COMMON) add_subdirectory(common) - if (LLAMA_HTTPLIB) - add_subdirectory(vendor/cpp-httplib) - endif() + add_subdirectory(vendor/cpp-httplib) endif() if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index c928bc39ce..7545e790f8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t 1. Explicitly disclose the manner in which AI was employed. 2. Perform a comprehensive manual review prior to submitting the pull request. 3. Be prepared to explain every line of code they submitted when asked about it by a maintainer. -4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited. +4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...). For more info, please refer to the [AGENTS.md](AGENTS.md) file. diff --git a/README.md b/README.md index dac020ad37..5c11f38048 100644 --- a/README.md +++ b/README.md @@ -288,6 +288,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo | [WebGPU [In Progress]](docs/build.md#webgpu) | All | | [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All | | [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon | +| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR | ## Obtaining and quantizing models diff --git a/SECURITY.md b/SECURITY.md index 9a93732318..3a8d07f644 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/ A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. > [!IMPORTANT] -> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 +> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 ## Requirements diff --git a/build-xcframework.sh b/build-xcframework.sh index 0eec871139..c25a1ef28c 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=( -DGGML_OPENMP=${GGML_OPENMP} ) -XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }') -MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1) -MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2) -echo "Detected Xcode version: $XCODE_VERSION" - check_required_tool() { local tool=$1 local install_message=$2 @@ -60,9 +55,12 @@ check_required_tool() { } echo "Checking for required tools..." check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)" -check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" -check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)" -check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" +check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" + +XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }') +MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1) +MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2) +echo "Detected Xcode version: $XCODE_VERSION" set -e @@ -260,7 +258,7 @@ combine_static_libraries() { # Since we have multiple architectures libtool will find object files that do not # match the target architecture. We suppress these warnings. - libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null + xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null # Determine SDK, architectures, and install_name based on platform and simulator flag. local sdk="" @@ -333,7 +331,7 @@ combine_static_libraries() { # Platform-specific post-processing for device builds if [[ "$is_simulator" == "false" ]]; then - if command -v xcrun vtool &>/dev/null; then + if xcrun -f vtool &>/dev/null; then case "$platform" in "ios") echo "Marking binary as a framework binary for iOS..." @@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \ -DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_OSX_SYSROOT=xros \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \ - -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ - -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ + -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DLLAMA_OPENSSL=OFF \ - -DLLAMA_HTTPLIB=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -S . cmake --build build-visionos --config Release -- -quiet @@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \ -DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_OSX_SYSROOT=xrsimulator \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \ - -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ - -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ + -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DLLAMA_OPENSSL=OFF \ - -DLLAMA_HTTPLIB=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -S . cmake --build build-visionos-sim --config Release -- -quiet @@ -528,13 +524,13 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false" # Create XCFramework with correct debug symbols paths echo "Creating XCFramework..." -xcodebuild -create-xcframework \ +xcrun xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-sim/framework/llama.framework \ -debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \ -framework $(pwd)/build-ios-device/framework/llama.framework \ -debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \ -framework $(pwd)/build-macos/framework/llama.framework \ - -debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \ + -debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \ -framework $(pwd)/build-visionos/framework/llama.framework \ -debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \ -framework $(pwd)/build-visionos-sim/framework/llama.framework \ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 295ae9ea25..27ca335be3 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -5,7 +5,6 @@ find_package(Threads REQUIRED) llama_add_compile_flags() # Build info header -# if(EXISTS "${PROJECT_SOURCE_DIR}/.git") set(GIT_DIR "${PROJECT_SOURCE_DIR}/.git") @@ -110,33 +109,16 @@ if (BUILD_SHARED_LIBS) set_target_properties(${TARGET} PROPERTIES POSITION_INDEPENDENT_CODE ON) endif() -# TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...) -set(LLAMA_COMMON_EXTRA_LIBS build_info) - -if (LLAMA_HTTPLIB) - target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB) - set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib) -endif() +target_link_libraries(${TARGET} PRIVATE + build_info + cpp-httplib +) if (LLAMA_LLGUIDANCE) include(ExternalProject) set(LLGUIDANCE_SRC ${CMAKE_BINARY_DIR}/llguidance/source) set(LLGUIDANCE_PATH ${LLGUIDANCE_SRC}/target/release) - - # Set the correct library file extension based on platform - if (WIN32) - set(LLGUIDANCE_LIB_NAME "llguidance.lib") - # Add Windows-specific libraries - set(LLGUIDANCE_PLATFORM_LIBS - ws2_32 # Windows Sockets API - userenv # For GetUserProfileDirectoryW - ntdll # For NT functions - bcrypt # For BCryptGenRandom - ) - else() - set(LLGUIDANCE_LIB_NAME "libllguidance.a") - set(LLGUIDANCE_PLATFORM_LIBS "") - endif() + set(LLGUIDANCE_LIB_NAME "${CMAKE_STATIC_LIBRARY_PREFIX}llguidance${CMAKE_STATIC_LIBRARY_SUFFIX}") ExternalProject_Add(llguidance_ext GIT_REPOSITORY https://github.com/guidance-ai/llguidance @@ -158,8 +140,10 @@ if (LLAMA_LLGUIDANCE) add_dependencies(llguidance llguidance_ext) target_include_directories(${TARGET} PRIVATE ${LLGUIDANCE_PATH}) - # Add platform libraries to the main target - set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} llguidance ${LLGUIDANCE_PLATFORM_LIBS}) -endif () + target_link_libraries(${TARGET} PRIVATE llguidance) + if (WIN32) + target_link_libraries(${TARGET} PRIVATE ws2_32 userenv ntdll bcrypt) + endif() +endif() -target_link_libraries(${TARGET} PRIVATE ${LLAMA_COMMON_EXTRA_LIBS} PUBLIC llama Threads::Threads) +target_link_libraries(${TARGET} PUBLIC llama Threads::Threads) diff --git a/common/arg.cpp b/common/arg.cpp index 5fbc9022c0..18f953a38e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.kv_unified = value; } - ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH})); + ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, @@ -3437,16 +3437,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.speculative.ngram_size_m = value; } ).set_examples({LLAMA_EXAMPLE_SERVER})); - add_opt(common_arg( - {"--spec-ngram-check-rate"}, "N", - string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate), - [](common_params & params, int value) { - if (value < 1) { - throw std::invalid_argument("ngram check rate must be at least 1"); - } - params.speculative.ngram_check_rate = value; - } - ).set_examples({LLAMA_EXAMPLE_SERVER})); add_opt(common_arg( {"--spec-ngram-min-hits"}, "N", string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits), diff --git a/common/chat.cpp b/common/chat.cpp index 2bf4632669..3c4e9f5cf0 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -65,14 +65,25 @@ json common_chat_msg::to_json_oaicompat(bool concat_typed_text) const { } else if (!content_parts.empty()) { if (concat_typed_text) { std::string text; + bool last_was_media_marker = false; + // join parts with newline, do not add newline before or after media markers for (const auto & part : content_parts) { - if (part.type != "text") { + bool add_new_line = true; + if (part.type == "text") { + add_new_line = !last_was_media_marker && !text.empty(); + last_was_media_marker = false; + } else if (part.type == "media_marker") { + add_new_line = false; + last_was_media_marker = true; + } else { LOG_WRN("Ignoring content part type: %s\n", part.type.c_str()); continue; } - if (!text.empty()) { + + if (add_new_line) { text += '\n'; } + text += part.text; } jmsg["content"] = text; @@ -319,7 +330,7 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa throw std::invalid_argument("Missing content part type: " + part.dump()); } const auto & type = part.at("type"); - if (type != "text") { + if (type != "text" && type != "media_marker") { throw std::invalid_argument("Unsupported content part type: " + type.dump()); } common_chat_msg_content_part msg_part; @@ -380,15 +391,46 @@ std::vector common_chat_msgs_parse_oaicompat(const json & messa return msgs; } -json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { +static json render_message_to_json(const std::vector & msgs, const jinja::caps & c) { + if (!c.supports_string_content && !c.supports_typed_content) { + LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__); + } + + bool only_string_accepted = c.supports_string_content && !c.supports_typed_content; + bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content; + json messages = json::array(); for (const auto & msg : msgs) { - json jmsg = msg.to_json_oaicompat(concat_typed_text); - messages.push_back(jmsg); + if (only_string_accepted) { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true); + messages.push_back(jmsg); + } else if (only_typed_accepted) { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false); + if (jmsg.at("content").is_string()) { + jmsg["content"] = json::array({ + json{ + {"type", "text"}, + {"text", jmsg.at("content").get()}, + } + }); + } + messages.push_back(jmsg); + } else { + json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false); + messages.push_back(jmsg); + } } return messages; } +// DEPRECATED: only used in tests +json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text) { + jinja::caps c; + c.supports_string_content = true; + c.supports_typed_content = !concat_typed_text; + return render_message_to_json(msgs, c); +} + std::vector common_chat_tools_parse_oaicompat(const json & tools) { std::vector result; @@ -3020,7 +3062,7 @@ static common_chat_params common_chat_templates_apply_jinja( : *tmpls->template_default; const auto & src = tmpl.source(); const auto & caps = tmpl.original_caps(); - params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content); + params.messages = render_message_to_json(inputs.messages, tmpl.original_caps()); params.add_generation_prompt = inputs.add_generation_prompt; params.tool_choice = inputs.tool_choice; params.reasoning_format = inputs.reasoning_format; @@ -3276,7 +3318,7 @@ static common_chat_params common_chat_templates_apply_legacy( for (const auto & msg : inputs.messages) { auto content = msg.content; for (const auto & part : msg.content_parts) { - if (part.type != "text") { + if (part.type != "text" && part.type != "media_marker") { LOG_WRN("Ignoring non-text content part: %s\n", part.type.c_str()); continue; } diff --git a/common/chat.h b/common/chat.h index 24aa4aab5c..1bf43f7261 100644 --- a/common/chat.h +++ b/common/chat.h @@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates * // Parses a JSON array of messages in OpenAI's chat completion API format. std::vector common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages); + +// DEPRECATED: only used in tests nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector & msgs, bool concat_typed_text = false); std::vector common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools); diff --git a/common/common.cpp b/common/common.cpp index 3aa396127c..75116ed6f3 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,7 +1,3 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml.h" #include "gguf.h" @@ -9,12 +5,12 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "unicode.h" #include #include #include #include -#include #include #include #include @@ -456,34 +452,6 @@ void string_replace_all(std::string & s, const std::string & search, const std:: s = std::move(builder); } -bool string_ends_with(const std::string_view & str, const std::string_view & suffix) { - return str.size() >= suffix.size() && str.compare(str.size()-suffix.size(), suffix.size(), suffix) == 0; -} - -bool string_remove_suffix(std::string & str, const std::string_view & suffix) { - bool has_suffix = string_ends_with(str, suffix); - if (has_suffix) { - str = str.substr(0, str.size() - suffix.size()); - } - return has_suffix; -} - -size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop) { - if (!str.empty() && !stop.empty()) { - const char text_last_char = str.back(); - for (int64_t char_index = stop.size() - 1; char_index >= 0; char_index--) { - if (stop[char_index] == text_last_char) { - const auto current_partial = stop.substr(0, char_index + 1); - if (string_ends_with(str, current_partial)) { - return str.size() - char_index - 1; - } - } - } - } - - return std::string::npos; -} - std::string regex_escape(const std::string & s) { static const std::regex special_chars("[.^$|()*+?\\[\\]{}\\\\]"); return std::regex_replace(s, special_chars, "\\$&"); @@ -706,45 +674,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { return false; } - std::u32string filename_utf32; - try { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif + size_t offset = 0; + while (offset < filename.size()) { + utf8_parse_result result = parse_utf8_codepoint(filename, offset); - std::wstring_convert, char32_t> converter; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - filename_utf32 = converter.from_bytes(filename); - - // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, - // or invalid encodings were encountered. Reject such attempts - std::string filename_reencoded = converter.to_bytes(filename_utf32); - if (filename_reencoded != filename) { + if (result.status != utf8_parse_result::SUCCESS) { return false; } - } catch (const std::exception &) { - return false; - } + uint32_t c = result.codepoint; - // Check for forbidden codepoints: - // - Control characters - // - Unicode equivalents of illegal characters - // - UTF-16 surrogate pairs - // - UTF-8 replacement character - // - Byte order mark (BOM) - // - Illegal characters: / \ : * ? " < > | - for (char32_t c : filename_utf32) { + if ((result.bytes_consumed == 2 && c < 0x80) || + (result.bytes_consumed == 3 && c < 0x800) || + (result.bytes_consumed == 4 && c < 0x10000)) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? " < > | if (c <= 0x1F // Control characters (C0) || c == 0x7F // Control characters (DEL) || (c >= 0x80 && c <= 0x9F) // Control characters (C1) @@ -752,6 +703,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { || c == 0x2215 // Division Slash (forward slash equivalent) || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c > 0x10FFFF // Max Unicode limit || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) || c == ':' || c == '*' // Illegal characters @@ -762,6 +714,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { // Subdirectories not allowed, reject path separators return false; } + offset += result.bytes_consumed; } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -898,7 +851,8 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else if (std::getenv("HOME")) { @@ -1242,7 +1196,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { return res; } - int err = llama_apply_adapter_cvec( + int err = llama_set_adapter_cvec( lctx, cvec.data.data(), cvec.data.size(), @@ -1344,12 +1298,15 @@ std::string get_model_endpoint() { } void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { - llama_clear_adapter_lora(ctx); - for (auto & la : lora) { - if (la.scale != 0.0f) { - llama_set_adapter_lora(ctx, la.ptr, la.scale); - } + std::vector loras; + std::vector scales; + + for (auto & la: lora) { + loras.push_back(la.ptr); + scales.push_back(la.scale); } + + llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data()); } struct llama_model_params common_model_params_to_llama(common_params & params) { @@ -1469,66 +1426,6 @@ void common_batch_add( batch.n_tokens++; } -// -// Token utils -// - -size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - // // Vocab utils // diff --git a/common/common.h b/common/common.h index 398ebb0960..a4c431172d 100644 --- a/common/common.h +++ b/common/common.h @@ -269,7 +269,6 @@ struct common_params_speculative { uint16_t ngram_size_n = 12; // ngram size for lookup uint16_t ngram_size_m = 48; // mgram size for speculative tokens - uint16_t ngram_check_rate = 1; // check rate for ngram lookup uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed std::shared_ptr ngram_mod; @@ -671,30 +670,55 @@ static std::vector string_split(const std::string & str, char delim) { } template<> -std::vector string_split(const std::string & input, char separator) +inline std::vector string_split(const std::string & str, char delim) { std::vector parts; size_t begin_pos = 0; - size_t separator_pos = input.find(separator); - while (separator_pos != std::string::npos) { - std::string part = input.substr(begin_pos, separator_pos - begin_pos); + size_t delim_pos = str.find(delim); + while (delim_pos != std::string::npos) { + std::string part = str.substr(begin_pos, delim_pos - begin_pos); parts.emplace_back(part); - begin_pos = separator_pos + 1; - separator_pos = input.find(separator, begin_pos); + begin_pos = delim_pos + 1; + delim_pos = str.find(delim, begin_pos); } - parts.emplace_back(input.substr(begin_pos, separator_pos - begin_pos)); + parts.emplace_back(str.substr(begin_pos)); return parts; } -static bool string_starts_with(const std::string & str, - const std::string & prefix) { // While we wait for C++20's std::string::starts_with... - return str.rfind(prefix, 0) == 0; +// remove when moving to c++20 +inline bool string_starts_with(std::string_view str, std::string_view prefix) { + return str.size() >= prefix.size() && + str.compare(0, prefix.size(), prefix) == 0; } -// While we wait for C++20's std::string::ends_with... -bool string_ends_with(const std::string_view & str, const std::string_view & suffix); -bool string_remove_suffix(std::string & str, const std::string_view & suffix); -size_t string_find_partial_stop(const std::string_view & str, const std::string_view & stop); +// remove when moving to c++20 +inline bool string_ends_with(std::string_view str, std::string_view suffix) { + return str.size() >= suffix.size() && + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; +} + +inline bool string_remove_suffix(std::string & str, std::string_view suffix) { + if (string_ends_with(str, suffix)) { + str.resize(str.size() - suffix.size()); + return true; + } + return false; +} + +inline size_t string_find_partial_stop(std::string_view str, std::string_view stop) { + if (!str.empty() && !stop.empty()) { + const size_t max_len = std::min(str.size(), stop.size()); + const char last_char = str.back(); + for (size_t len = max_len; len > 0; --len) { + if (stop[len - 1] == last_char) { + if (string_ends_with(str, stop.substr(0, len))) { + return str.size() - len; + } + } + } + } + return std::string::npos; +} bool string_parse_kv_override(const char * data, std::vector & overrides); void string_process_escapes(std::string & input); @@ -780,16 +804,6 @@ void common_batch_add( const std::vector & seq_ids, bool logits); -// -// Token utils -// - -// longest common prefix -size_t common_lcp(const llama_tokens & a, const llama_tokens & b); - -// longet common subsequence -size_t common_lcs(const llama_tokens & a, const llama_tokens & b); - // // Vocab utils // @@ -881,11 +895,11 @@ const char * const LLM_KV_SPLIT_TENSORS_COUNT = "split.tensors.count"; const char * const LLM_FFN_EXPS_REGEX = "\\.ffn_(up|down|gate)_(ch|)exps"; -static std::string llm_ffn_exps_block_regex(int idx) { +inline std::string llm_ffn_exps_block_regex(int idx) { return string_format("blk\\.%d%s", idx, LLM_FFN_EXPS_REGEX); } -static llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { +inline llama_model_tensor_buft_override llm_ffn_exps_cpu_override() { return { LLM_FFN_EXPS_REGEX, ggml_backend_cpu_buffer_type() }; } diff --git a/common/download.cpp b/common/download.cpp index 57f29a23ba..5ef60a4208 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -19,9 +19,7 @@ #include #include -#if defined(LLAMA_USE_HTTPLIB) #include "http.h" -#endif #ifndef __EMSCRIPTEN__ #ifdef __linux__ @@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) { } static std::string read_etag(const std::string & path) { - std::string none; const std::string etag_path = path + ".etag"; - - if (std::filesystem::exists(etag_path)) { - std::ifstream etag_in(etag_path); - if (!etag_in) { - LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); - return none; - } - std::string etag; - std::getline(etag_in, etag); - return etag; + if (!std::filesystem::exists(etag_path)) { + return {}; } - - // no etag file, but maybe there is an old .json - // remove this code later - const std::string metadata_path = path + ".json"; - - if (std::filesystem::exists(metadata_path)) { - std::ifstream metadata_in(metadata_path); - try { - nlohmann::json metadata_json; - metadata_in >> metadata_json; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata_json.dump().c_str()); - if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { - std::string etag = metadata_json.at("etag"); - write_etag(path, etag); - if (!std::filesystem::remove(metadata_path)) { - LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); - } - return etag; - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } + std::ifstream etag_in(etag_path); + if (!etag_in) { + LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); + return {}; } - return none; + std::string etag; + std::getline(etag_in, etag); + return etag; } static bool is_http_status_ok(int status) { @@ -168,8 +140,6 @@ std::pair common_download_split_repo_tag(const std::st return {hf_repo, tag}; } -#if defined(LLAMA_USE_HTTPLIB) - class ProgressBar { static inline std::mutex mutex; static inline std::map lines; @@ -305,7 +275,10 @@ static bool common_pull_file(httplib::Client & cli, ); if (!res) { - LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1); + LOG_ERR("%s: download failed: %s (status: %d)\n", + __func__, + httplib::to_string(res.error()).c_str(), + res ? res->status : -1); return false; } @@ -344,62 +317,64 @@ static int common_download_file_single_online(const std::string & url, LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } - for (int i = 0; i < max_attempts; ++i) { - auto head = cli.Head(parts.path); - bool head_ok = head && head->status >= 200 && head->status < 300; - if (!head_ok) { - LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); - if (file_exists) { - LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); - return 304; // 304 Not Modified - fake cached response - } - return head->status; // cannot use cached file, return raw status code - // TODO: maybe retry only on certain codes - } - - std::string etag; - if (head_ok && head->has_header("ETag")) { - etag = head->get_header_value("ETag"); - } - - size_t total_size = 0; - if (head_ok && head->has_header("Content-Length")) { - try { - total_size = std::stoull(head->get_header_value("Content-Length")); - } catch (const std::exception& e) { - LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); - } - } - - bool supports_ranges = false; - if (head_ok && head->has_header("Accept-Ranges")) { - supports_ranges = head->get_header_value("Accept-Ranges") != "none"; - } - - bool should_download_from_scratch = false; - if (!last_etag.empty() && !etag.empty() && last_etag != etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, - last_etag.c_str(), etag.c_str()); - should_download_from_scratch = true; - } - + auto head = cli.Head(parts.path); + if (!head || head->status < 200 || head->status >= 300) { + LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1); if (file_exists) { - if (!should_download_from_scratch) { - LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - return 304; // 304 Not Modified - fake cached response - } - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return -1; - } + LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + return head ? head->status : -1; + } + + std::string etag; + if (head->has_header("ETag")) { + etag = head->get_header_value("ETag"); + } + + size_t total_size = 0; + if (head->has_header("Content-Length")) { + try { + total_size = std::stoull(head->get_header_value("Content-Length")); + } catch (const std::exception& e) { + LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what()); + } + } + + bool supports_ranges = false; + if (head->has_header("Accept-Ranges")) { + supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + } + + if (file_exists) { + if (etag.empty()) { + LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + if (!last_etag.empty() && last_etag == etag) { + LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return -1; + } + } + + const std::string path_temporary = path + ".downloadInProgress"; + int delay = retry_delay_seconds; + + for (int i = 0; i < max_attempts; ++i) { + if (i) { + LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay); + std::this_thread::sleep_for(std::chrono::seconds(delay)); + delay *= retry_delay_seconds; } - const std::string path_temporary = path + ".downloadInProgress"; size_t existing_size = 0; if (std::filesystem::exists(path_temporary)) { - if (supports_ranges && !should_download_from_scratch) { + if (supports_ranges) { existing_size = std::filesystem::file_size(path_temporary); } else if (remove(path_temporary.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); @@ -407,32 +382,23 @@ static int common_download_file_single_online(const std::string & url, } } - // start the download - LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); - const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); - if (!was_pull_successful) { - if (i + 1 < max_attempts) { - const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; - LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else { - LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); + LOG_INF("%s: downloading from %s to %s (etag:%s)...\n", + __func__, common_http_show_masked_url(parts).c_str(), + path_temporary.c_str(), etag.c_str()); + + if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) { + if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return -1; } - continue; + if (!etag.empty()) { + write_etag(path, etag); + } + return head->status; } - - if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return -1; - } - if (!etag.empty()) { - write_etag(path, etag); - } - - return head->status; // TODO: use actual GET status? } + LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); return -1; // max attempts reached } @@ -798,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) { } } -#else - -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -std::string common_docker_resolve_model(const std::string &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -int common_download_file_single(const std::string &, - const std::string &, - const std::string &, - bool, - const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -#endif // defined(LLAMA_USE_HTTPLIB) - std::vector common_list_cached_models() { std::vector models; const std::string cache_dir = fs_get_cache_directory(); diff --git a/common/jinja/caps.cpp b/common/jinja/caps.cpp index f27490f1fb..dbaaed500a 100644 --- a/common/jinja/caps.cpp +++ b/common/jinja/caps.cpp @@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) { std::map caps::to_map() const { return { - {"requires_typed_content", requires_typed_content}, + {"supports_string_content", supports_string_content}, + {"supports_typed_content", supports_typed_content}, {"supports_tools", supports_tools}, {"supports_tool_calls", supports_tool_calls}, {"supports_parallel_tool_calls", supports_parallel_tool_calls}, @@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) { return v->stats.ops.find(op_name) != v->stats.ops.end(); }; - // case: typed content requirement + // case: typed content support caps_try_execute( prog, [&]() { @@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) { // tools return json{nullptr}; }, - [&](bool, value & messages, value &) { + [&](bool success, value & messages, value &) { auto & content = messages->at(0)->at("content"); caps_print_stats(content, "messages[0].content"); if (has_op(content, "selectattr") || has_op(content, "array_access")) { // accessed as an array - result.requires_typed_content = true; + result.supports_typed_content = true; + } + if (!success) { + // failed to execute with content as string + result.supports_string_content = false; } } ); diff --git a/common/jinja/caps.h b/common/jinja/caps.h index 77df117baa..e694e7bfaa 100644 --- a/common/jinja/caps.h +++ b/common/jinja/caps.h @@ -14,7 +14,9 @@ struct caps { bool supports_parallel_tool_calls = true; bool supports_preserve_reasoning = false; // support assistant message with reasoning_content - bool requires_typed_content = false; // default: use string content + // one of the 2 content capabilities must be true + bool supports_string_content = true; + bool supports_typed_content = false; // for reporting on server std::map to_map() const; diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index 4453d86e6d..cc012c892f 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) { value iterable_val = iter_expr->execute(scope); + // mark the variable being iterated as used for stats + if (ctx.is_get_stats) { + iterable_val->stats.used = true; + iterable_val->stats.ops.insert("array_access"); + } + if (iterable_val->is_undefined()) { JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop"); iterable_val = mk_val(); diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 2aa156b177..9987836d18 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -4,6 +4,7 @@ // for converting from JSON to jinja values #include +#include #include #include #include @@ -715,8 +716,46 @@ const func_builtins & value_string_t::get_builtins() const { return args.get_pos(0); }}, {"tojson", tojson}, - {"indent", [](const func_args &) -> value { - throw not_implemented_exception("String indent builtin not implemented"); + {"indent", [](const func_args &args) -> value { + args.ensure_count(1, 4); + value val_input = args.get_pos(0); + value val_width = args.get_kwarg_or_pos("width", 1); + const bool first = args.get_kwarg_or_pos("first", 2)->as_bool(); // undefined == false + const bool blank = args.get_kwarg_or_pos("blank", 3)->as_bool(); // undefined == false + if (!is_val(val_input)) { + throw raised_exception("indent() first argument must be a string"); + } + std::string indent; + if (is_val(val_width)) { + indent.assign(val_width->as_int(), ' '); + } else if (is_val(val_width)) { + indent = val_width->as_string().str(); + } else { + indent = " "; + } + std::string indented; + std::string input = val_input->as_string().str(); + std::istringstream iss = std::istringstream(input); + std::string line; + while (std::getline(iss, line)) { + if (!indented.empty()) { + indented.push_back('\n'); + } + if ((indented.empty() ? first : (!line.empty() || blank))) { + indented += indent; + } + indented += line; + } + if (!input.empty() && input.back() == '\n') { + indented.push_back('\n'); + if (blank) { + indented += indent; + } + } + + auto res = mk_val(indented); + res->val_str.mark_input_based_on(val_input->as_string()); + return res; }}, {"join", [](const func_args &) -> value { throw not_implemented_exception("String join builtin not implemented"); diff --git a/common/ngram-map.cpp b/common/ngram-map.cpp index c5b8fc75ed..ebf771a24a 100644 --- a/common/ngram-map.cpp +++ b/common/ngram-map.cpp @@ -231,10 +231,9 @@ void common_ngram_map_draft(common_ngram_map & map, GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len); } - // Only check every check_rate tokens to save compute - // i.e., perform check if (cur_len - idx_last_check) >= check_rate - if (map.idx_last_check + map.check_rate > cur_len) { - return; + if (map.idx_last_check > cur_len) { + // Should not happen because of common_ngram_map_begin(). + GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len); } map.idx_last_check = cur_len; @@ -462,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map, slot_max = v; } } - // What is sum of the other occurences? + // What is sum of the other occurrences? uint32_t sum_occur = 0; for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) { if (v == slot_max) { diff --git a/common/ngram-map.h b/common/ngram-map.h index 9668bd5a7c..d84e719151 100644 --- a/common/ngram-map.h +++ b/common/ngram-map.h @@ -24,7 +24,6 @@ struct common_ngram_simple_config { uint16_t size_ngram; // size of n-grams to lookup in self-mode uint16_t size_mgram; // size of m-grams to draft in self-mode - uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token }; // Searches for a n-gram in the history and checks whether a draft sequence should be generated. @@ -45,7 +44,7 @@ llama_tokens common_ngram_simple_draft( // statistics of a m-gram after a known n-gram struct common_ngram_map_value { size_t value_idx = 0; // index of value m-gram in token-history (0 if unused) - uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot) + uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot) int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused) }; @@ -54,7 +53,7 @@ struct common_ngram_map_key { size_t key_idx; // index of key n-gram in token-history size_t stat_idx; // index of last token of stastistics computation (key_num, values) - uint16_t key_num; // number of occurences of this key n-gram in token-history + uint16_t key_num; // number of occurrences of this key n-gram in token-history common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key }; @@ -66,15 +65,14 @@ struct common_ngram_map { bool key_only; // true if only key n-grams are used, no values. std::vector keys; // key n-grams which occur several times in token-history - uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token uint16_t min_hits; // minimum number of key hits to consider a draft - bool show_key_map_stats = false; // true, if statitics of the key_map should be printed. + bool show_key_map_stats = false; // true, if statistics of the key_map should be printed. common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys, - uint16_t check_rate, uint16_t min_hits) + uint16_t min_hits) : size_key(sz_key), size_value(sz_value), key_only(only_keys), - check_rate(check_rate), min_hits(min_hits) { + min_hits(min_hits) { key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used } diff --git a/common/speculative.cpp b/common/speculative.cpp index 84d2556ceb..3e68c38e49 100644 --- a/common/speculative.cpp +++ b/common/speculative.cpp @@ -113,13 +113,14 @@ static bool common_speculative_are_compatible( struct common_speculative_state { const enum common_speculative_type type; - // TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens - // TODO: add n_call_begin, n_call_accept - size_t drafts_call_count = 0; // number of times this implementation was called. - size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation. - size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model. - size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation. - size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model. + size_t n_call_begin = 0; // number of times this implementation was called for refresh. + size_t n_call_draft = 0; // number of times this implementation was called for generation. + size_t n_call_accept = 0; // number of times this implementation was called for accumulation. + + size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation. + size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model. + size_t n_gen_tokens = 0; // number of tokens generated by this implementation. + size_t n_acc_tokens = 0; // number of tokens accepted by the target model. // TODO: track performance of most recent calls const bool gen_perf = true; // whether to generate performance stats. @@ -465,8 +466,6 @@ struct common_speculative_state_eagle3 : public common_speculative_state { struct common_speculative_state_ngram_simple : public common_speculative_state { common_ngram_simple_config config; - uint16_t check_id = 0; // used to control the frequency of generating drafts - common_speculative_state_ngram_simple( enum common_speculative_type type, common_ngram_simple_config config) @@ -481,11 +480,6 @@ struct common_speculative_state_ngram_simple : public common_speculative_state { const llama_tokens & prompt_tgt, llama_token id_last, llama_tokens & result) override { - ++check_id; - if (check_id < config.check_rate) { - return; - } - check_id = 0; result = common_ngram_simple_draft(config, prompt_tgt, id_last); GGML_UNUSED(params); @@ -752,10 +746,9 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c uint16_t size_key = config.params.ngram_size_n; uint16_t size_value = config.params.ngram_size_m; bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); - uint16_t check_rate = config.params.ngram_check_rate; uint16_t min_hits = config.params.ngram_min_hits; - return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits); + return common_ngram_map(size_key, size_value, key_only, min_hits); } static common_speculative_state_ngram_cache create_state_ngram_cache( @@ -931,12 +924,10 @@ common_speculative * common_speculative_init( uint16_t ngram_size_key = ngram_map.size_key; uint16_t mgram_size_value = ngram_map.size_value; - uint16_t check_rate = ngram_map.check_rate; auto config_simple = common_ngram_simple_config { /* .size_ngram = */ ngram_size_key, - /* .size_mgram = */ mgram_size_value, - /* .check_rate = */ check_rate + /* .size_mgram = */ mgram_size_value }; auto state = std::make_unique( /* .type = */ config.type, @@ -997,6 +988,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr for (auto & impl : spec->impls) { common_time_meas tm(impl->t_begin_us, !impl->gen_perf); impl->begin(prompt); + impl->n_call_begin++; } } @@ -1013,17 +1005,17 @@ llama_tokens common_speculative_draft( { common_time_meas tm(impl->t_draft_us, !impl->gen_perf); impl->draft(params, prompt_tgt, id_last, result); - impl->drafts_call_count++; + impl->n_call_draft++; } if (!result.empty()) { LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(), - impl.get()->drafts_call_count, result.size()); + impl.get()->n_call_draft, result.size()); spec->curr_impl = impl.get(); // set current implementation for stats - impl->drafts_generated_count++; - impl->drafts_generated_tokens += result.size(); + impl->n_gen_drafts++; + impl->n_gen_tokens += result.size(); break; // We have a draft, so break out of the loop and return it. } @@ -1044,11 +1036,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) { { common_time_meas tm(impl->t_accept_us, !impl->gen_perf); if (n_accepted > 0) { - impl->drafts_accepted_count++; - impl->drafts_accepted_tokens += n_accepted; + impl->n_acc_drafts++; + impl->n_acc_tokens += n_accepted; } impl->accept(n_accepted); + impl->n_call_accept++; } } @@ -1069,13 +1062,13 @@ void common_speculative_print_stats(const common_speculative * spec) { str_perf = ""; } - LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", + LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n", common_speculative_type_to_str(impl->type).c_str(), - impl->drafts_call_count, - impl->drafts_generated_count, - impl->drafts_accepted_count, - impl->drafts_generated_tokens, - impl->drafts_accepted_tokens, + impl->n_call_begin, impl->n_call_draft, impl->n_call_accept, + impl->n_gen_drafts, + impl->n_acc_drafts, + impl->n_gen_tokens, + impl->n_acc_tokens, str_perf.c_str()); } } diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 843c00a896..7eeb3aa903 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -160,8 +160,6 @@ class ModelBase: self.ftype = gguf.LlamaFileType.MOSTLY_F16 logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16") - self.dequant_model() - # Configure GGUF Writer self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file, split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard) @@ -527,6 +525,8 @@ class ModelBase: return () def prepare_tensors(self): + self.dequant_model() + # Handle empty tensor_map for models with block_count=0 (like MobileNetV5) if self.tensor_map.mapping: max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,") @@ -570,6 +570,7 @@ class ModelBase: self.match_model_tensor_name(new_name, key, bid) for key in ( gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.FFN_GATE_INP_SHEXP, gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, gguf.MODEL_TENSOR.SSM_CONV1D, @@ -1048,6 +1049,9 @@ class TextModel(ModelBase): if chkhsh == "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902": # ref: https://huggingface.co/zai-org/GLM-4.5-Air res = "glm4" + if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267": + # ref: https://huggingface.co/zai-org/GLM-4.7-Flash + res = "glm4" if chkhsh == "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35": # ref: https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0 res = "minerva-7b" @@ -1081,9 +1085,6 @@ class TextModel(ModelBase): if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df": # ref: https://huggingface.co/aari1995/German_Semantic_V3 res = "jina-v2-de" - if chkhsh == "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267": - # ref: https://huggingface.co/zai-org/GLM-4.7-Flash - res = "glm4" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -1123,6 +1124,9 @@ class TextModel(ModelBase): if chkhsh == "9c2227e4dd922002fb81bde4fc02b0483ca4f12911410dee2255e4987644e3f8": # ref: https://huggingface.co/CohereForAI/c4ai-command-r-v01 res = "command-r" + if chkhsh == "d772b220ace2baec124bed8cfafce0ead7d6c38a4b65ef11261cf9d5d62246d1": + # ref: https://huggingface.co/CohereLabs/tiny-aya-base + res = "tiny_aya" if chkhsh == "e636dc30a262dcc0d8c323492e32ae2b70728f4df7dfe9737d9f920a282b8aea": # ref: https://huggingface.co/Qwen/Qwen1.5-7B res = "qwen2" @@ -1159,6 +1163,9 @@ class TextModel(ModelBase): if chkhsh == "b53802fb28e26d645c3a310b34bfe07da813026ec7c7716883404d5e0f8b1901": # ref: https://huggingface.co/core42/jais-13b res = "jais" + if chkhsh == "bc5108ee1eb6a3d600cadd065f63190fbd0554dbc9e4bbd6a0d977970afc8d2a": + # ref: https://huggingface.co/inceptionai/Jais-2-8B-Chat + res = "jais-2" if chkhsh == "7b3e7548e4308f52a76e8229e4e6cc831195d0d1df43aed21ac6c93da05fec5f": # ref: https://huggingface.co/WisdomShell/CodeShell-7B res = "codeshell" @@ -1261,6 +1268,12 @@ class TextModel(ModelBase): if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f": # ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B res = "exaone-moe" + if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4": + # ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct + res = "qwen35" + if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d": + # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash + res = "joyai-llm" if res is None: logger.warning("\n") @@ -1605,6 +1618,23 @@ class TextModel(ModelBase): special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_glm(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_interns1(self): tokens: list[str] = [] toktypes: list[int] = [] @@ -1812,7 +1842,7 @@ class MmprojModel(ModelBase): preprocessor_config: dict[str, Any] global_config: dict[str, Any] - n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"] + n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"] has_vision_encoder: bool = True # by default has_audio_encoder: bool = False @@ -1867,7 +1897,15 @@ class MmprojModel(ModelBase): preprocessor_config_path = self.dir_model / "preprocessor_config.json" if preprocessor_config_path.is_file(): with open(preprocessor_config_path, "r", encoding="utf-8") as f: - self.preprocessor_config = json.load(f) + cfg = json.load(f) + # move media_proc_cfg to root level for compat + if "media_proc_cfg" in cfg: + cfg = { + **cfg, + **cfg["media_proc_cfg"], + } + # merge configs + self.preprocessor_config = {**self.preprocessor_config, **cfg} # prefer processor_config.json if possible processor_config_path = self.dir_model / "processor_config.json" @@ -1916,10 +1954,10 @@ class MmprojModel(ModelBase): self.image_size = self.find_vparam(["image_size"]) self.gguf_writer.add_vision_image_size(self.image_size) self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"])) - self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"])) - self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"])) + self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"])) + self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"])) self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys)) - self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"])) + self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"])) # preprocessor config image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"] @@ -2697,8 +2735,6 @@ class AfmoeModel(LlamaModel): super().set_gguf_parameters() # MoE parameters - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None: self.gguf_writer.add_expert_shared_count(n_shared_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: @@ -2720,7 +2756,7 @@ class AfmoeModel(LlamaModel): # Handle expert weights - they're already merged in the HF format # process the experts separately if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -4045,6 +4081,87 @@ class InternVisionModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register( + "NemotronH_Nano_VL_V2", + "RADIOModel", +) +class NemotronNanoV2VLModel(MmprojModel): + # ViT-Huge architecture parameters for RADIO v2.5-h + _vit_hidden_size = 1280 + _vit_intermediate_size = 5120 + _vit_num_layers = 32 + _vit_num_heads = 16 + + def get_vision_config(self) -> dict[str, Any] | None: + # RADIO config doesn't have standard ViT parameters, so they need to be constructed manually + vision_config = self.global_config.get("vision_config") + if vision_config is None: + return None + # Add ViT-H parameters + vision_config = { + **vision_config, + "hidden_size": self._vit_hidden_size, + "intermediate_size": self._vit_intermediate_size, + "num_hidden_layers": self._vit_num_layers, + "num_attention_heads": self._vit_num_heads, + "image_size": self.global_config.get("force_image_size", 512), + } + return vision_config + + def set_gguf_parameters(self): + if "image_mean" not in self.preprocessor_config: + self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406] + if "image_std" not in self.preprocessor_config: + self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225] + + super().set_gguf_parameters() + hparams = self.global_config + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL) + self.gguf_writer.add_vision_attention_layernorm_eps(1e-6) + self.gguf_writer.add_vision_use_gelu(True) + downsample_ratio = hparams.get("downsample_ratio", 0.5) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + if ".position_embd." in new_name or "pos_embed" in new_name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "input_conditioner" in name: + return + + # RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it + if "patch_generator.pos_embed" in name: + if not name.endswith(".weight"): + name += ".weight" + # Downsample position embeddings for fixed 512x512 image size + import torch.nn.functional as F + n_embd = self.hparams["hidden_size"] + image_size = self.global_config.get("force_image_size", 512) + patch_size = self.hparams["patch_size"] + target_patches_per_side = image_size // patch_size # 32 + max_patches_per_side = int((data_torch.shape[1]) ** 0.5) # 128 + if target_patches_per_side != max_patches_per_side: + # Reshape to grid, interpolate, flatten back + data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd) + data_torch = data_torch.permute(0, 3, 1, 2).float() # [1, n_embd, 128, 128] + data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side), + mode='bilinear', align_corners=True) + data_torch = data_torch.permute(0, 2, 3, 1) # [1, 32, 32, n_embd] + data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd) + + # Reshape linear patch embedding to conv2d format for ggml_conv_2d + # From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size] + if "patch_generator.embedder" in name: + patch_size = self.hparams["patch_size"] + n_embd = self.hparams["hidden_size"] + data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size) + + if name.startswith("vision_model.radio_model.model.") or name.startswith("mlp1."): + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("WavTokenizerDec") class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -4087,8 +4204,6 @@ class Qwen2MoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") @@ -4109,39 +4224,31 @@ class Qwen2MoeModel(TextModel): # Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): mapped = f"{name}.weight" if not name.endswith(".weight") else name - # Input: (n_expert=128, n_ff_exp=768, n_embd=2048) - # Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128} - # Need PyTorch: (128, 2048, 768) [reversed of GGML] - # So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768) - permuted = data_torch.permute(0, 2, 1).contiguous() - yield from super().modify_tensors(permuted, mapped, bid) + # HF: [n_expert, n_embd, n_ff] -> GGML: {n_ff, n_embd, n_expert} + yield from super().modify_tensors(data_torch, mapped, bid) return if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): - if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: + if data_torch.ndim < 3 or data_torch.shape[-2] % 2 != 0: raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") - split_dim = data_torch.shape[-1] // 2 - gate = data_torch[..., :split_dim].contiguous() - up = data_torch[..., split_dim:].contiguous() - # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) - # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} - # Need PyTorch: (128, 768, 2048) [reversed of GGML] - # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) - base_name = name.removesuffix(".weight") - base = base_name.rsplit('.', 1)[0] - mapped_gate = f"{base}.gate_proj.weight" - mapped_up = f"{base}.up_proj.weight" - perm_gate = gate.permute(0, 2, 1).contiguous() - perm_up = up.permute(0, 2, 1).contiguous() - yield from super().modify_tensors(perm_gate, mapped_gate, bid) - yield from super().modify_tensors(perm_up, mapped_up, bid) + # HF: [n_expert, 2*n_ff, n_embd] -> split on dim=-2 + n_ff = data_torch.shape[-2] // 2 + gate = data_torch[..., :n_ff, :].contiguous() + up = data_torch[..., n_ff:, :].contiguous() + # gate/up: [n_expert, n_ff, n_embd] -> GGML: {n_embd, n_ff, n_expert} + base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj") + mapped_gate = f"{base_name}.gate_proj.weight" + mapped_up = f"{base_name}.up_proj.weight" + yield from super().modify_tensors(gate, mapped_gate, bid) + yield from super().modify_tensors(up, mapped_up, bid) return if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"): # skip visual tensors return + if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -4295,6 +4402,7 @@ class Qwen3NextModel(Qwen2MoeModel): self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"]) self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"]) self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"]) + self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4)) if (rope_dim := self.hparams.get("head_dim")) is None: rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"] self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25))) @@ -4359,7 +4467,7 @@ class RND1Model(Qwen2MoeModel): self.gguf_writer.add_mask_token_id(mask_token_id) -@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration") +@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration") class Qwen3VLVisionModel(MmprojModel): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) @@ -4405,6 +4513,10 @@ class Qwen3VLVisionModel(MmprojModel): if name.startswith("model.language_model.") or name.startswith("lm_head."): return + # Skip MTP tensors + if name.startswith("mtp."): + return + if name.startswith("model.visual."): name = name.replace("model.visual.", "visual.", 1) @@ -4475,7 +4587,7 @@ class Qwen3VLVisionModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) -@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration") +@ModelBase.register("Glm4vForConditionalGeneration", "Glm4vMoeForConditionalGeneration", "GlmOcrForConditionalGeneration") class Glm4VVisionModel(Qwen3VLVisionModel): def set_gguf_parameters(self): MmprojModel.set_gguf_parameters(self) # skip Qwen3VLVisionModel parameters @@ -4535,9 +4647,125 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel): if name.startswith("model.visual."): return + # Qwen3VL has transposed packed tensors, so we treat it differently from general Qwen2MoE packed tensors + if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"): + name = name.replace("language_model.", "") + mapped = f"{name}.weight" if not name.endswith(".weight") else name + permuted = data_torch.permute(0, 2, 1).contiguous() + yield from ModelBase.modify_tensors(self, permuted, mapped, bid) + return + + if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"): + name = name.replace("language_model.", "") + if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0: + raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}") + split_dim = data_torch.shape[-1] // 2 + gate = data_torch[..., :split_dim].contiguous() + up = data_torch[..., split_dim:].contiguous() + # Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768) + # Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128} + # Need PyTorch: (128, 768, 2048) [reversed of GGML] + # So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048) + base_name = name.removesuffix(".weight") + base = base_name.rsplit('.', 1)[0] + mapped_gate = f"{base}.gate_proj.weight" + mapped_up = f"{base}.up_proj.weight" + perm_gate = gate.permute(0, 2, 1).contiguous() + perm_up = up.permute(0, 2, 1).contiguous() + yield from ModelBase.modify_tensors(self, perm_gate, mapped_gate, bid) + yield from ModelBase.modify_tensors(self, perm_up, mapped_up, bid) + return + yield from super().modify_tensors(data_torch, name, bid) +class _LinearAttentionVReorderBase(Qwen3NextModel): + model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses + """reorders V heads from grouped to tiled order for ggml broadcast + + see https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306 + + Linear attention may has num_k_heads < num_v_heads. The HF weights store + V heads grouped by K head: [G0_v0..v{r-1}, G1_v0..v{r-1}, ...]. + ggml binary ops use tiled broadcast: [K0, K1, ..., K0, K1, ...]. + We reorder V heads to tiled order so ggml_repeat can replace the expensive + interleaved repeat: [G0_v0, G1_v0, ..., G0_v1, G1_v1, ...]. + """ + + @staticmethod + def _reorder_v_heads(tensor: Tensor, dim: int, num_k_heads: int, num_v_per_k: int, head_dim: int) -> Tensor: + """Reorder V heads from grouped (by K head) to tiled order along the given dimension.""" + shape = list(tensor.shape) + if dim < 0: + dim += len(shape) + new_shape = shape[:dim] + [num_k_heads, num_v_per_k, head_dim] + shape[dim + 1:] + tensor = tensor.reshape(*new_shape) + perm = list(range(len(new_shape))) + perm[dim], perm[dim + 1] = perm[dim + 1], perm[dim] + return tensor.permute(*perm).contiguous().reshape(*shape) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + num_k_heads = self.hparams.get("linear_num_key_heads", 0) + num_v_heads = self.hparams.get("linear_num_value_heads", 0) + + if num_k_heads > 0 and num_v_heads > 0 and num_k_heads != num_v_heads and "linear_attn." in name: + head_k_dim = self.hparams["linear_key_head_dim"] + head_v_dim = self.hparams["linear_value_head_dim"] + num_v_per_k = num_v_heads // num_k_heads + + if ".in_proj_qkv." in name: + # QKV weight: reorder only the V rows + q_dim = head_k_dim * num_k_heads + k_dim = head_k_dim * num_k_heads + q = data_torch[:q_dim] + k = data_torch[q_dim:q_dim + k_dim] + v = data_torch[q_dim + k_dim:] + v = self._reorder_v_heads(v, 0, num_k_heads, num_v_per_k, head_v_dim) + data_torch = torch.cat([q, k, v], dim=0) + + elif ".in_proj_z." in name: + # Z gate weight: reorder rows (num_v_heads * head_v_dim) + data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, head_v_dim) + + elif ".in_proj_b." in name or ".in_proj_a." in name: + # Beta/Alpha weight: reorder rows (num_v_heads, head_dim=1) + data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, 1) + + elif ".A_log" in name or ".dt_bias" in name or ".dt_proj" in name: + # A_log / dt_bias: 1D parameters with num_v_heads elements + if data_torch.ndim == 1: + data_torch = self._reorder_v_heads( + data_torch.unsqueeze(-1), 0, num_k_heads, num_v_per_k, 1 + ).squeeze(-1) + else: + data_torch = self._reorder_v_heads(data_torch, -1, num_k_heads, num_v_per_k, 1) + + elif ".conv1d" in name: + # Conv1d kernel: reorder only the V channel portion + data = data_torch.squeeze() + qk_channels = head_k_dim * num_k_heads * 2 + qk_part = data[:qk_channels] + v_part = data[qk_channels:] + v_part = self._reorder_v_heads(v_part, 0, num_k_heads, num_v_per_k, head_v_dim) + data_torch = torch.cat([qk_part, v_part], dim=0) + + elif ".out_proj." in name: + # Out projection weight: reorder columns (input dimension) + data_torch = self._reorder_v_heads(data_torch, 1, num_k_heads, num_v_per_k, head_v_dim) + + yield from super().modify_tensors(data_torch, name, bid) + + +@ModelBase.register("Qwen3_5ForConditionalGeneration") +class Qwen3_5TextModel(_LinearAttentionVReorderBase): + model_arch = gguf.MODEL_ARCH.QWEN35 + + +@ModelBase.register("Qwen3_5MoeForConditionalGeneration") +class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase): + model_arch = gguf.MODEL_ARCH.QWEN35MOE + + @ModelBase.register("GPT2LMHeadModel") class GPT2Model(TextModel): model_arch = gguf.MODEL_ARCH.GPT2 @@ -4771,13 +4999,13 @@ class PhiMoeModel(Phi3MiniModel): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) - self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"])) + self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"])) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("block_sparse_moe.experts") != -1: - n_experts = self.hparams["num_local_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -5189,7 +5417,7 @@ class KimiLinearModel(TextModel): # process the experts separately if name.find("block_sparse_moe.experts") != -1: - n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False) + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -5784,12 +6012,13 @@ class NomicBertModel(BertModel): if "mlp.experts.bias" in name: return # Explicitly return. + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) if "mlp.experts.mlp.w1" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"]) name += ".weight" if "mlp.experts.mlp.w2" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"]) data_torch = data_torch.transpose(1, 2) name += ".weight" @@ -5799,7 +6028,6 @@ class NomicBertModel(BertModel): super().set_gguf_parameters() if self.is_moe: self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"]) - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) def _is_tokenizer_xlmroberta(self) -> bool: @@ -6913,6 +7141,8 @@ class Mamba2Model(TextModel): if hparams is None: with open(dir_model / "config.json", "r", encoding="utf-8") as f: hparams = json.load(f) + if "llm_config" in hparams: + hparams["text_config"] = hparams["llm_config"] super().__init__(dir_model, *args, hparams=hparams, **kwargs) self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model @@ -7034,8 +7264,8 @@ class JambaModel(TextModel): self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_time_step_rank(dt_rank) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) - self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"])) + self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"])) self.gguf_writer.add_file_type(self.ftype) _experts: list[dict[str, Tensor]] | None = None @@ -7053,7 +7283,7 @@ class JambaModel(TextModel): # process the experts separately if ".feed_forward.experts." in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None @@ -7139,6 +7369,17 @@ class Cohere2Model(TextModel): self.gguf_writer.add_rope_dimension_count(int(rotary_pct * (hidden_size // num_attention_heads))) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Cohere2 runtime in llama.cpp expects no bias tensors; + # the actual weight only contains 0-value tensors as bias, we can skip them + if name.endswith(".bias"): + if torch.any(data_torch != 0): + raise ValueError(f"Bias tensor {name!r} is not zero.") + logger.debug(f"Skipping bias tensor {name!r} for Cohere2 conversion.") + return + + yield from super().modify_tensors(data_torch, name, bid) + @ModelBase.register("OlmoForCausalLM") @ModelBase.register("OLMoForCausalLM") @@ -7201,8 +7442,6 @@ class OlmoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_layer_norm_rms_eps(1e-5) - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) _experts: list[dict[str, Tensor]] | None = None @@ -7210,7 +7449,7 @@ class OlmoeModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -7579,12 +7818,16 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "KimiK25ForConditionalGeneration", "YoutuForCausalLM", "YoutuVLForConditionalGeneration", ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + # TODO @ngxson : remove this when we support MTP for deepseek models + skip_mtp = True + def set_vocab(self): try: self._set_vocab_gpt2() @@ -7697,8 +7940,8 @@ class DeepseekV2Model(TextModel): _experts: list[dict[str, Tensor]] | None = None def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # skip vision tensors and remove "language_model." for Kimi-VL - if "vision_tower" in name or "multi_modal_projector" in name: + # skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5 + if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name: return if name.startswith("siglip2.") or name.startswith("merger."): return @@ -7716,10 +7959,11 @@ class DeepseekV2Model(TextModel): name = name.replace("e_score_correction_bias", "e_score_correction.bias") # skip Multi-Token Prediction (MTP) layers - block_count = self.hparams["num_hidden_layers"] - match = re.match(r"model.layers.(\d+)", name) - if match and int(match.group(1)) >= block_count: - return + if self.skip_mtp: + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return # process the experts separately if name.find("mlp.experts") != -1: @@ -7786,10 +8030,6 @@ class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 _experts_cache: dict[int, dict[str, Tensor]] = {} - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hparams["num_experts"] = self.hparams["num_local_experts"] - def set_gguf_parameters(self): super().set_gguf_parameters() @@ -7802,7 +8042,7 @@ class MiniMaxM2Model(TextModel): # merge expert weights if 'experts' in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None expert_cache = self._experts_cache.setdefault(bid, {}) @@ -8396,6 +8636,17 @@ class T5EncoderModel(TextModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Jais2ForCausalLM") +class Jais2Model(TextModel): + model_arch = gguf.MODEL_ARCH.JAIS2 + + def set_gguf_parameters(self): + super().set_gguf_parameters() + hparams = self.hparams + head_dim = hparams.get("head_dim", hparams["hidden_size"] // hparams["num_attention_heads"]) + self.gguf_writer.add_rope_dimension_count(head_dim) + + @ModelBase.register("JAISLMHeadModel") class JaisModel(TextModel): model_arch = gguf.MODEL_ARCH.JAIS @@ -8539,7 +8790,7 @@ class Glm4Model(TextModel): n_head = self.hparams["num_attention_heads"] n_kv_head = self.hparams["num_key_value_heads"] n_embd = self.hparams["hidden_size"] - head_dim = n_embd // n_head + head_dim = self.hparams.get("head_dim", n_embd // n_head) # because llama.cpp M-RoPE kernel only supports Neox ordering, we have to permute the weights here if name.endswith(("q_proj.weight", "q_proj.bias")): data_torch = Glm4Model.normal_to_neox(data_torch, n_head, n_head, head_dim, self.partial_rotary_factor) @@ -8548,6 +8799,27 @@ class Glm4Model(TextModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("GlmOcrForConditionalGeneration") +class GlmOCRModel(Glm4Model): + model_arch = gguf.MODEL_ARCH.GLM4 + use_mrope = False + partial_rotary_factor = 0.5 + + # Note: GLM-OCR is the same as GLM4, but with an extra NextN/MTP prediction layer + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # GLM-OCR has num_hidden_layers + 1 actual layers (including NextN layer) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + @ModelBase.register("Glm4MoeForCausalLM", "Glm4vMoeForConditionalGeneration") class Glm4MoeModel(TextModel): model_arch = gguf.MODEL_ARCH.GLM4_MOE @@ -8559,24 +8831,7 @@ class Glm4MoeModel(TextModel): self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def set_vocab(self): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) - - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - - special_vocab.add_to_gguf(self.gguf_writer) + return self._set_vocab_glm() def set_gguf_parameters(self): super().set_gguf_parameters() @@ -8676,26 +8931,38 @@ class Glm4MoeModel(TextModel): class Glm4MoeLiteModel(DeepseekV2Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 - # copied from Glm4MoeModel def set_vocab(self): - from transformers import AutoTokenizer + return self._set_vocab_glm() - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 +@ModelBase.register("GlmMoeDsaForCausalLM") +class GlmMoeDsaModel(DeepseekV2Model): + model_arch = gguf.MODEL_ARCH.GLM_DSA + skip_mtp = False - special_vocab.add_to_gguf(self.gguf_writer) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + return self._set_vocab_glm() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + rope_dim = self.hparams["qk_rope_head_dim"] + partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + # DSA indexer parameters + self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"]) + self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) + self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") @@ -9012,7 +9279,6 @@ class ExaoneMoEModel(Exaone4Model): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) moe_intermediate_size = self.hparams["moe_intermediate_size"] num_shared_experts = self.hparams["num_shared_experts"] self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) @@ -9053,7 +9319,7 @@ class ExaoneMoEModel(Exaone4Model): name = name.replace("e_score_correction_bias", "e_score_correction.bias") if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9204,7 +9470,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): # case, the model architecture needs to be updated to a standard # "granite" or "granitemoe" model if not self._ssm_layers: - has_experts = self.find_hparam(["num_experts_per_tok"], optional=True) + has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True) new_arch = ( gguf.MODEL_ARCH.GRANITE_MOE if has_experts else @@ -9400,6 +9666,14 @@ class NemotronHModel(GraniteHybridModel): self.gguf_writer.add_add_bos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL) + if name.startswith(("vision_model.", "mlp1.")): + return + + # Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL) + if name.startswith("language_model."): + name = name[len("language_model."):] + if self.is_moe and bid is not None: if name.endswith("mixer.gate.e_score_correction_bias"): new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -9494,7 +9768,6 @@ class BailingMoeModel(TextModel): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_weights_scale(1.0) - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) @@ -9528,7 +9801,7 @@ class BailingMoeModel(TextModel): yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid) return elif name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9599,7 +9872,6 @@ class BailingMoeV2Model(TextModel): self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"])) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) @@ -9610,7 +9882,7 @@ class BailingMoeV2Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if "mlp.experts" in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9656,8 +9928,6 @@ class GroveMoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") @@ -9678,7 +9948,7 @@ class GroveMoeModel(TextModel): # process the experts separately if name.find("chunk_experts") != -1: - n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2 # see add_experts_per_group assert bid is not None if self._chunk_experts is None: @@ -9705,7 +9975,7 @@ class GroveMoeModel(TextModel): else: return elif name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10098,7 +10368,6 @@ class HunYuanMoEModel(TextModel): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) moe_intermediate_size = hparams["moe_intermediate_size"] @@ -10141,7 +10410,7 @@ class HunYuanMoEModel(TextModel): return if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10183,16 +10452,9 @@ class LLaDAMoEModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size) - # number of experts used per token (top-k) - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(n_experts_used) - self.gguf_writer.add_mask_token_id(156895) self.gguf_writer.add_causal_attention(False) self.gguf_writer.add_diffusion_shift_logits(False) @@ -10203,7 +10465,7 @@ class LLaDAMoEModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10478,7 +10740,7 @@ class LFM2Model(TextModel): def set_gguf_parameters(self): # set num_key_value_heads only for attention layers self.hparams["num_key_value_heads"] = [ - self.hparams["num_key_value_heads"] if layer_type == "full_attention" else 0 + self.hparams["num_key_value_heads"] if layer_type != "conv" else 0 for layer_type in self.hparams["layer_types"] ] @@ -10540,7 +10802,6 @@ class LFM2MoeModel(TextModel): super().set_gguf_parameters() - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"]) self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) @@ -10561,7 +10822,7 @@ class LFM2MoeModel(TextModel): # merge expert weights if 'experts' in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None expert_cache = self._experts_cache.setdefault(bid, {}) @@ -10665,15 +10926,37 @@ class LFM2AudioModel(ConformerAudioModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("Lfm25AudioTokenizer") +class LFM25AudioTokenizer(LFM2Model): + model_arch = gguf.MODEL_ARCH.LFM2 + + def set_vocab(self): + self._set_vocab_none() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) + self.gguf_writer.add_embedding_length_out(self.hparams["output_size"]) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if name == "istft.window" or name.startswith("emb.emb"): + return + + if name.startswith("lin"): + name = name.replace("lin", "dense_2_out") + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("SmallThinkerForCausalLM") class SmallThinkerModel(TextModel): model_arch = gguf.MODEL_ARCH.SMALLTHINKER def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None: + if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) - if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None: + if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) @@ -10698,7 +10981,7 @@ class SmallThinkerModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts")) + n_experts = self.hparams.get("moe_num_primary_experts") or self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10756,13 +11039,17 @@ class ModernBertModel(BertModel): self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - # these layers act as MLM head, so we don't need them - if name.startswith("decoder."): - return - if name.startswith("model."): name = name[6:] + if self.cls_out_labels: + # For BertForSequenceClassification (direct projection layer) + if name == "classifier.weight": + name = "classifier.out_proj.weight" + + if name == "classifier.bias": + name = "classifier.out_proj.bias" + yield from super().modify_tensors(data_torch, name, bid) @@ -11060,6 +11347,103 @@ class KimiVLModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("KimiK25ForConditionalGeneration") +class KimiK25Model(MmprojModel): + """Kimi-K2.5 with MoonViT3d vision encoder""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config" + + self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2])) + self.patch_size = self.hparams_vision.get("patch_size", 14) + + # Set image_size for compatibility with base class + # Use position embedding dimensions as image_size reference + pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64) + self.hparams_vision["image_size"] = pos_emb_h * self.patch_size + + def set_gguf_parameters(self): + # Base class MmprojModel.set_gguf_parameters() already writes: + # - vision_block_count, vision_head_count, vision_embedding_length + # - vision_feed_forward_length, vision_patch_size, image_mean, image_std + # via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config + super().set_gguf_parameters() + assert self.hparams_vision is not None + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25) + + # Position embedding parameters (for interpolation) + self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64)) + self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4)) + + # Projector parameters + self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu") + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5)) + self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0]) + + # Image size limits + # Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet) + in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384) + min_patches = 8 # reasonable minimum + pixels_per_patch = self.patch_size ** 2 + self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch) + self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch) + + @staticmethod + def permute(weights: Tensor, n_head: int) -> Tensor: + out_dim, in_dim = weights.shape + head_dim = out_dim // n_head + w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim) + w = w.permute(0, 2, 1, 3, 4) + return w.reshape(out_dim, in_dim) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Only process vision and projector tensors + is_vision = any(x in name for x in ["vision_tower", "mm_projector"]) + + if not is_vision: + return + + assert self.hparams_vision is not None + n_head = self.hparams_vision.get("num_attention_heads", 16) + + # Permute Q/K weights/biases from interleaved to split RoPE format + # This allows using build_rope_2d at runtime without post-permutation. + if "wqkv" in name: + out_dim = data_torch.shape[0] + qkv_dim = out_dim // 3 + head_dim = qkv_dim // n_head + + if "weight" in name: + wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :] + wq = self.permute(wq, n_head) + wk = self.permute(wk, n_head) + data_torch = torch.cat([wq, wk, wv], dim=0) + elif "bias" in name: + bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:] + bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1) + data_torch = torch.cat([bq, bk, bv], dim=0) + + # Temporal embeddings: (T, 1, C) → (T, C) + if "pos_emb.time_weight" in name: + T, _, C = data_torch.shape + data_torch = data_torch.reshape(T, C) + + # PatchMergerMLP tensor name mapping + # proj.0.weight → proj.linear_1.weight + # proj.2.weight → proj.linear_2.weight + if "mm_projector.proj.0." in name: + name = name.replace(".proj.0.", ".proj.linear_1.") + elif "mm_projector.proj.2." in name: + name = name.replace(".proj.2.", ".proj.linear_2.") + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("CogVLMForCausalLM") class CogVLMVisionModel(MmprojModel): diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 2811f7f884..8f7443d1b5 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -99,6 +99,7 @@ models = [ {"name": "stablelm2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/stabilityai/stablelm-2-zephyr-1_6b", }, {"name": "refact", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/smallcloudai/Refact-1_6-base", }, {"name": "command-r", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereForAI/c4ai-command-r-v01", }, + {"name": "tiny_aya", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/CohereLabs/tiny-aya-base", }, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen1.5-7B", }, {"name": "olmo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/allenai/OLMo-1.7-7B-hf", }, {"name": "dbrx", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/databricks/dbrx-base", }, @@ -113,6 +114,7 @@ models = [ {"name": "gemma", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2b", }, {"name": "gemma-2", "tokt": TOKENIZER_TYPE.SPM, "repo": "https://huggingface.co/google/gemma-2-9b", }, {"name": "jais", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/core42/jais-13b", }, + {"name": "jais-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/inceptionai/Jais-2-8B-Chat", }, {"name": "t5", "tokt": TOKENIZER_TYPE.UGM, "repo": "https://huggingface.co/google-t5/t5-small", }, {"name": "codeshell", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/WisdomShell/CodeShell-7B", }, {"name": "tekken", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/mistralai/Mistral-Nemo-Base-2407", }, @@ -148,6 +150,8 @@ models = [ {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", }, + {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }, + {"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -157,6 +161,7 @@ pre_computed_hashes = [ {"name": "chatglm-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-chat", "chkhsh": "81d72c7348a9f0ebe86f23298d37debe0a5e71149e29bd283904c02262b27516"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/THUDM/glm-4-9b-hf", "chkhsh": "a1336059768a55c99a734006ffb02203cd450fed003e9a71886c88acf24fdbc2"}, {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.5-Air", "chkhsh": "9ca2dd618e8afaf09731a7cf6e2105b373ba6a1821559f258b272fe83e6eb902"}, + {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"}, {"name": "minerva-7b", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/sapienzanlp/Minerva-7B-base-v1.0", "chkhsh": "1431a23e583c97432bc230bff598d103ddb5a1f89960c8f1d1051aaa944d0b35"}, {"name": "hunyuan", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-A13B-Instruct", "chkhsh": "7e57df22b1fe23a7b1e1c7f3dc4e3f96d43a4eb0836d0c6bdc3436d7b2f1c664"}, {"name": "hunyuan-dense", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Hunyuan-4B-Instruct", "chkhsh": "bba3b3366b646dbdded5dbc42d59598b849371afc42f7beafa914afaa5b70aa6"}, @@ -170,7 +175,6 @@ pre_computed_hashes = [ {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, # jina-v2-de variants {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"}, - {"name": "glm4", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/zai-org/GLM-4.7-Flash", "chkhsh": "cdf5f35325780597efd76153d4d1c16778f766173908894c04afc20108536267"}, ] diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index b03c2a122c..23b6a62763 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -246,7 +246,7 @@ cmake --build build --config release 1. **Retrieve and prepare model** - You can refer to the general [*Prepare and Quantize*](../../README.md#prepare-and-quantize) guide for model prepration. + You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model prepration. **Notes**: diff --git a/docs/backend/SYCL.md b/docs/backend/SYCL.md index b3cff96604..07c68be5cb 100644 --- a/docs/backend/SYCL.md +++ b/docs/backend/SYCL.md @@ -281,7 +281,7 @@ as `-cl-fp32-correctly-rounded-divide-sqrt` #### Retrieve and prepare model -You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). +You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/resolve/main/llama-2-7b.Q4_0.gguf?download=true) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). ##### Check device @@ -569,7 +569,7 @@ Once it is completed, final results will be in **build/Release/bin** #### Retrieve and prepare model -You can refer to the general [*Prepare and Quantize*](README.md#prepare-and-quantize) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). +You can refer to the general [*Obtaining and quantizing models*](../../README.md#obtaining-and-quantizing-models) guide for model preparation, or download an already quantized model like [llama-2-7b.Q4_0.gguf](https://huggingface.co/TheBloke/Llama-2-7B-GGUF/blob/main/llama-2-7b.Q4_0.gguf) or [Meta-Llama-3-8B-Instruct-Q4_0.gguf](https://huggingface.co/aptha/Meta-Llama-3-8B-Instruct-Q4_0-GGUF/resolve/main/Meta-Llama-3-8B-Instruct-Q4_0.gguf). ##### Check device diff --git a/docs/backend/VirtGPU.md b/docs/backend/VirtGPU.md new file mode 100644 index 0000000000..c81468da13 --- /dev/null +++ b/docs/backend/VirtGPU.md @@ -0,0 +1,180 @@ +# GGML-VirtGPU Backend + +The GGML-VirtGPU backend enables GGML applications to run machine +learning computations on host hardware while the application itself +runs inside a virtual machine. It uses host-guest shared memory to +efficiently share data buffers between the two sides. + +This backend relies on the virtio-gpu, and VirglRenderer API Remoting +(APIR) component. The backend is split into two libraries: +- a GGML implementation (the "remoting frontend"), running in the + guest and interacting with the virtgpu device +- a VirglRenderer APIR compatible library (the "remoting backend"), + running in the host and interacting with Virglrenderer and an actual + GGML device backend. + +## OS support + +| OS | Status | Backend | CI testing | Notes +| -------- | ----------------- | ----------- | ----------- | ----- +| MacOS 14 | Supported | ggml-metal | X | Working when compiled on MacOS 14 +| MacOS 15 | Supported | ggml-metal | X | Working when compiled on MacOS 14 or MacOS 15 +| MacOS 26 | Not tested | | | +| Linux | Under development | ggml-vulkan | not working | Working locally, CI running into deadlocks + + +## Architecture Overview + +The GGML-VirtGPU backend consists of three main components: + +```mermaid +graph TD + %% Nodes + + subgraph GuestVM ["Guest VM - Frontend"] + App([GGML Application
llama.cpp, etc.]) + + direction TB + Interface[GGML Backend Interface] + Comm["GGML-VirtGPU
(hypercalls + shared mem)"] + + App --> Interface + Interface --> Comm + end + + API[virtio-gpu / virglrenderer API] + + subgraph HostSystem [Host System - Backend] + direction TB + Dispatcher[GGML-VirtGPU-Backend] + BackendLib[GGML Backend library
Metal / Vulkan / CPU / ...] + + Dispatcher --> BackendLib + end + + %% Connections + Comm --> API + API --> HostSystem +``` + +### Key Components + +1. **Guest-side Frontend** (`ggml-virtgpu/`): Implements the GGML backend interface and forwards operations to the host +2. **Host-side Backend** (`ggml-virtgpu/backend/`): Receives forwarded operations and executes them on actual hardware backends +3. **Communication Layer**: Uses virtio-gpu hypercalls and shared memory for efficient data transfer + +## Features + +- **Dynamic backend loading** on the host side (CPU, CUDA, Metal, etc.) +- **Zero-copy data transfer** via host-guest shared memory pages + +## Communication Protocol + +### Hypercalls and Shared Memory + +The backend uses two primary communication mechanisms: + +1. **Hypercalls (`DRM_IOCTL_VIRTGPU_EXECBUFFER`)**: Trigger remote execution from guest to host +2. **Shared Memory Pages**: Zero-copy data transfer for tensors and parameters + +#### Shared Memory Layout + +Each connection uses two shared memory buffers: + +- **Data Buffer** (24 MiB): For command/response data and tensor transfers +- **Reply Buffer** (16 KiB): For command replies and status information +- **Data Buffers**: Dynamically allocated host-guest shared buffers + served as GGML buffers. + +### APIR Protocol + +The Virglrender API Remoting protocol defines three command types: + +- `HANDSHAKE`: Protocol version negotiation and capability discovery +- `LOADLIBRARY`: Dynamic loading of backend libraries on the host +- `FORWARD`: API function call forwarding + +### Binary Serialization + +Commands and data are serialized using a custom binary protocol with: + +- Fixed-size encoding for basic types +- Variable-length arrays with size prefixes +- Buffer bounds checking +- Error recovery mechanisms + +## Supported Operations + +### Device Operations +- Device enumeration and capability queries +- Memory information (total/free) +- Backend type detection + +### Buffer Operations +- Buffer allocation and deallocation +- Tensor data transfer (host ↔ guest) +- Memory copying and clearing + +### Computation Operations +- Graph execution forwarding + +## Build Requirements + +### Guest-side Dependencies +- `libdrm` for DRM/virtio-gpu communication +- C++20 compatible compiler +- CMake 3.14+ + +### Host-side Dependencies +- virglrenderer with APIR support (pending upstream review) +- Target backend libraries (libggml-metal, libggml-vulkan, etc.) + +## Configuration + +### Environment Variables + +- `GGML_VIRTGPU_BACKEND_LIBRARY`: Path to the host-side backend library +- `GGML_VIRTGPU_DEBUG`: Enable debug logging + +### Build Options + +- `GGML_VIRTGPU`: Enable the VirtGPU backend (`ON` or `OFF`, default: `OFF`) +- `GGML_VIRTGPU_BACKEND`: Build the host-side backend component (`ON`, `OFF` or `ONLY`, default: `OFF`) + +### System Requirements + +- VM with virtio-gpu support +- VirglRenderer with APIR patches +- Compatible backend libraries on host + +## Limitations + +- **VM-specific**: Only works in virtual machines with virtio-gpu support +- **Host dependency**: Requires properly configured host-side backend +- **Latency**: Small overhead from VM escaping for each operation + + +* This work is pending upstream changes in the VirglRenderer + project. + * The backend can be tested with Virglrenderer compiled from source + using this PR: + https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590 +* This work is pending changes in the VMM/hypervisor running the + virtual machine, which need to know how to route the newly + introduced APIR capset. + * The environment variable `VIRGL_ROUTE_VENUS_TO_APIR=1` allows + using the Venus capset, until the relevant hypervisors have been + patched. However, setting this flag breaks the Vulkan/Venus normal + behavior. + * The environment variable `GGML_REMOTING_USE_APIR_CAPSET` tells the + `ggml-virtgpu` backend to use the APIR capset. This will become + the default when the relevant hypervisors have been patched. + +* This work focused on improving the performance of llama.cpp running + on MacOS containers, and is mainly tested on this platform. The + linux support (via `krun`) is in progress. + +## See Also + +- [Development and Testing](VirtGPU/development.md) +- [Backend configuration](VirtGPU/configuration.md) diff --git a/docs/backend/VirtGPU/configuration.md b/docs/backend/VirtGPU/configuration.md new file mode 100644 index 0000000000..597862d5c8 --- /dev/null +++ b/docs/backend/VirtGPU/configuration.md @@ -0,0 +1,174 @@ +# GGML-VirtGPU Backend Configuration + +This document describes the environment variables used by the ggml-virtgpu backend system, covering both the frontend (guest-side) and backend (host-side) components. + +## Environment Variables Overview + +The ggml-virtgpu backend uses environment variables for configuration across three main components: +- **Frontend (Guest)**: GGML applications running in VMs +- **Hypervisor**: Virglrenderer/APIR system +- **Backend (Host)**: Host-side GGML backend integration + +## Frontend (Guest-side) Configuration + +### GGML_REMOTING_USE_APIR_CAPSET +- **Location**: `ggml/src/ggml-virtgpu/virtgpu.cpp` +- **Type**: Boolean flag (presence-based) +- **Purpose**: Controls which virtio-gpu capability set to use for communication +- **Values**: + - Set (any value): Use the APIR capset (long-term setup) + - Unset: Use the Venus capset (easier for testing with an unmodified hypervisor) +- **Default**: Unset (Venus capset) +- **Usage**: + ```bash + export GGML_REMOTING_USE_APIR_CAPSET=1 # Use APIR capset + # or leave unset for Venus capset + ``` + +## Hypervisor (Virglrenderer/APIR) Configuration + +These environment variables are used during the transition phase for +running with an unmodified hypervisor (not supporting the +VirglRenderer APIR component). They will be removed in the future, and +the hypervisor will instead configure VirglRenderer with the APIR +_Configuration Key_. + +### VIRGL_APIR_BACKEND_LIBRARY +- **Location**: `virglrenderer/src/apir/apir-context.c` +- **Configuration Key**: `apir.load_library.path` +- **Type**: File path string +- **Purpose**: Path to the APIR backend library that virglrenderer should dynamically load +- **Required**: Yes +- **Example**: + ```bash + export VIRGL_APIR_BACKEND_LIBRARY="/path/to/libggml-remotingbackend.so" + ``` + +### VIRGL_ROUTE_VENUS_TO_APIR +- **Location**: `virglrenderer/src/apir/apir-renderer.h` +- **Type**: Boolean flag (presence-based) +- **Purpose**: Temporary workaround to route Venus capset calls to APIR during hypervisor transition period +- **Status**: will be removed once hypervisors support APIR natively +- **Warning**: Breaks normal Vulkan/Venus functionality +- **Usage**: + ```bash + export VIRGL_ROUTE_VENUS_TO_APIR=1 # For testing with an unmodified hypervisor + ``` + +### VIRGL_APIR_LOG_TO_FILE +- **Location**: `virglrenderer/src/apir/apir-renderer.c` +- **Environment Variable**: `VIRGL_APIR_LOG_TO_FILE` +- **Type**: File path string +- **Purpose**: Enable debug logging from the VirglRenderer APIR component to specified file +- **Required**: No (optional debugging) +- **Default**: Logging to `stderr` +- **Usage**: + ```bash + export VIRGL_APIR_LOG_TO_FILE="/tmp/apir-debug.log" + ``` + +## Backend (Host-side) Configuration + +These environment variables are used during the transition phase for +running with an unmodified hypervisor (not supporting the +VirglRenderer APIR component). They will be removed in the future, and +the hypervisor will instead configure VirglRenderer with the APIR +_Configuration Key_. + +### APIR_LLAMA_CPP_GGML_LIBRARY_PATH +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp` +- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_PATH` +- **Configuration Key**: `ggml.library.path` +- **Type**: File path string +- **Purpose**: Path to the actual GGML backend library (Metal, CUDA, Vulkan, etc.) +- **Required**: **Yes** - backend initialization fails without this +- **Examples**: + ```bash + # macOS with Metal backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib" + + # Linux with CUDA backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-cuda.so" + + # macOS or Linux with Vulkan backend + export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-vulkan.so" + ``` + +### APIR_LLAMA_CPP_GGML_LIBRARY_REG +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp` +- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_REG` +- **Configuration Key**: `ggml.library.reg` +- **Type**: Function symbol name string +- **Purpose**: Name of the backend registration function to call after loading the library +- **Required**: No (defaults to `ggml_backend_init`) +- **Default**: `ggml_backend_init` +- **Examples**: + ```bash + # Metal backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg" + + # CUDA backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_cuda_reg" + + # Vulkan backend + export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_vulkan_reg" + + # Generic fallback (default) + # export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_init" + ``` + +### APIR_LLAMA_CPP_LOG_TO_FILE +- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp:62` +- **Environment Variable**: `APIR_LLAMA_CPP_LOG_TO_FILE` +- **Type**: File path string +- **Purpose**: Enable debug logging from the GGML backend to specified file +- **Required**: No (optional debugging) +- **Usage**: + ```bash + export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml-backend-debug.log" + ``` + +## Configuration Flow + +The configuration system works as follows: + +1. **Hypervisor Setup**: Virglrenderer loads the APIR backend library specified by `VIRGL_APIR_BACKEND_LIBRARY` + +2. **Context Creation**: When an APIR context is created, it populates a configuration table with environment variables: + - `apir.load_library.path` ← `VIRGL_APIR_BACKEND_LIBRARY` + - `ggml.library.path` ← `APIR_LLAMA_CPP_GGML_LIBRARY_PATH` + - `ggml.library.reg` ← `APIR_LLAMA_CPP_GGML_LIBRARY_REG` + - this step will eventually be performed by the hypervisor itself, with command-line arguments instead of environment variables. + +3. **Backend Initialization**: The backend queries the configuration via callbacks: + - `virgl_cbs->get_config(ctx_id, "ggml.library.path")` returns the library path + - `virgl_cbs->get_config(ctx_id, "ggml.library.reg")` returns the registration function + +4. **Library Loading**: The backend dynamically loads and initializes the specified GGML library + +## Error Messages + +Common error scenarios and their messages: + +- **Missing library path**: `"cannot open the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_PATH' not defined"` +- **Missing registration function**: `"cannot register the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_REG' not defined"` + +## Example Complete Configuration + +Here's an example configuration for a macOS host with Metal backend: + +```bash +# Hypervisor environment +export VIRGL_APIR_BACKEND_LIBRARY="/opt/llama.cpp/lib/libggml-virtgpu-backend.dylib" + +# Backend configuration +export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib" +export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg" + +# Optional logging +export VIRGL_APIR_LOG_TO_FILE="/tmp/apir.log" +export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml.log" + +# Guest configuration +export GGML_REMOTING_USE_APIR_CAPSET=1 +``` diff --git a/docs/backend/VirtGPU/development.md b/docs/backend/VirtGPU/development.md new file mode 100644 index 0000000000..ca2e47772a --- /dev/null +++ b/docs/backend/VirtGPU/development.md @@ -0,0 +1,220 @@ +# Development and Testing + +## Development + +### Code Generation + +The backend uses code generation from YAML configuration: + +```bash +# Regenerate protocol code +cd ggml-virtgpu/ +python regenerate_remoting.py +``` + +### Adding New Operations + +1. Add function definition to `ggmlremoting_functions.yaml` +2. Regenerate code with `regenerate_remoting.py` +3. Implement guest-side forwarding in `virtgpu-forward-*.cpp` +4. Implement host-side handling in `backend-dispatched-*.cpp` + +## Testing + +This document provides instructions for building and testing the GGML-VirtGPU backend on macOS with containers. + +### Prerequisites + +The testing setup requires: + +- macOS host system +- Container runtime with `libkrun` provider (podman machine) +- Access to development patchset for VirglRenderer + +### Required Patchsets + +The backend requires patches that are currently under review: + +- **Virglrenderer APIR upstream PR**: https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590 (for reference) +- **MacOS Virglrenderer (for krunkit)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-macos +- **Linux Virglrenderer (for krun)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-linux + +### Build Instructions + +#### 1. Build ggml-virtgpu-backend (Host-side, macOS) + +```bash +# Build the backend that runs natively on macOS +mkdir llama.cpp +cd llama.cpp +git clone https://github.com/ggml-org/llama.cpp.git src +cd src + +LLAMA_MAC_BUILD=$PWD/build/ggml-virtgpu-backend + +cmake -S . -B $LLAMA_MAC_BUILD \ + -DGGML_NATIVE=OFF \ + -DLLAMA_CURL=ON \ + -DGGML_REMOTINGBACKEND=ONLY \ + -DGGML_METAL=ON + +TARGETS="ggml-metal" +cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $TARGETS + +# Build additional tools for native benchmarking +EXTRA_TARGETS="llama-run llama-bench" +cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $EXTRA_TARGETS +``` + +#### 2. Build virglrenderer (Host-side, macOS) + +```bash +# Build virglrenderer with APIR support +mkdir virglrenderer +git clone https://gitlab.freedesktop.org/kpouget/virglrenderer -b main-macos src +cd src + +VIRGL_BUILD_DIR=$PWD/build + +# -Dvenus=true and VIRGL_ROUTE_VENUS_TO_APIR=1 route the APIR requests via the Venus backend, for easier testing without a patched hypervisor + +meson setup $VIRGL_BUILD_DIR \ + -Dvenus=true \ + -Dapir=true + +ninja -C $VIRGL_BUILD_DIR +``` + +#### 3. Build ggml-virtgpu (Guest-side, Linux) + +Option A: Build from a script: + +```bash +# Inside a Linux container +mkdir llama.cpp +git clone https://github.com/ggml-org/llama.cpp.git src +cd src + +LLAMA_LINUX_BUILD=$PWD//build-virtgpu + +cmake -S . -B $LLAMA_LINUX_BUILD \ + -DGGML_VIRTGPU=ON + +ninja -C $LLAMA_LINUX_BUILD +``` + +Option B: Build container image with frontend: + +```bash +cat << EOF > remoting.containerfile +FROM quay.io/fedora/fedora:43 +USER 0 + +WORKDIR /app/remoting + +ARG LLAMA_CPP_REPO="https://github.com/ggml-org/llama.cpp.git" +ARG LLAMA_CPP_VERSION="master" +ARG LLAMA_CPP_CMAKE_FLAGS="-DGGML_VIRTGPU=ON" +ARG LLAMA_CPP_CMAKE_BUILD_FLAGS="--parallel 4" + +RUN dnf install -y git cmake gcc gcc-c++ libcurl-devel libdrm-devel + +RUN git clone "\${LLAMA_CPP_REPO}" src \\ + && git -C src fetch origin \${LLAMA_CPP_VERSION} \\ + && git -C src reset --hard FETCH_HEAD + +RUN mkdir -p build \\ + && cd src \\ + && set -o pipefail \\ + && cmake -S . -B ../build \${LLAMA_CPP_CMAKE_FLAGS} \\ + && cmake --build ../build/ \${LLAMA_CPP_CMAKE_BUILD_FLAGS} + +ENTRYPOINT ["/app/remoting/src/build/bin/llama-server"] +EOF + +mkdir -p empty_dir +podman build -f remoting.containerfile ./empty_dir -t localhost/llama-cpp.virtgpu +``` + +### Environment Setup + +#### Set krunkit Environment Variables + +```bash +# Define the base directories (adapt these paths to your system) +VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build +LLAMA_MAC_BUILD=$HOME/remoting/llama.cpp/build-backend + +# For krunkit to load the custom virglrenderer library +export DYLD_LIBRARY_PATH=$VIRGL_BUILD_DIR/src + +# For Virglrenderer to load the ggml-remotingbackend library +export VIRGL_APIR_BACKEND_LIBRARY="$LLAMA_MAC_BUILD/bin/libggml-virtgpu-backend.dylib" + +# For llama.cpp remotingbackend to load the ggml-metal backend +export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="$LLAMA_MAC_BUILD/bin/libggml-metal.dylib" +export APIR_LLAMA_CPP_GGML_LIBRARY_REG=ggml_backend_metal_reg +``` + +#### Launch Container Environment + +```bash +# Set container provider to libkrun +export CONTAINERS_MACHINE_PROVIDER=libkrun +podman machine start +``` + +#### Verify Environment + +Confirm that krunkit is using the correct virglrenderer library: + +```bash +lsof -c krunkit | grep virglrenderer +# Expected output: +# krunkit 50574 user txt REG 1,14 2273912 10849442 ($VIRGL_BUILD_DIR/src)/libvirglrenderer.1.dylib +``` + +### Running Tests + +#### Launch Test Container + +```bash +# Optional model caching +mkdir -p models +PODMAN_CACHE_ARGS="-v models:/models --user root:root --cgroupns host --security-opt label=disable -w /models" + +podman run $PODMAN_CACHE_ARGS -it --rm --device /dev/dri localhost/llama-cpp.virtgpu +``` + +#### Test llama.cpp in Container + +```bash + +# Run performance benchmark +/app/remoting/build/bin/llama-bench -m ./llama3.2 +``` + +Expected output (performance may vary): +``` +| model | size | params | backend | ngl | test | t/s | +| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: | +| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | pp512 | 991.30 ± 0.66 | +| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | tg128 | 85.71 ± 0.11 | +``` + +### Troubleshooting + +#### SSH Environment Variable Issues + +⚠️ **Warning**: Setting `DYLD_LIBRARY_PATH` from SSH doesn't work on macOS. Here is a workaround: + +**Workaround 1: Replace system library** +```bash +VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build # ⚠️ adapt to your system +BREW_VIRGL_DIR=/opt/homebrew/Cellar/virglrenderer/0.10.4d/lib +VIRGL_LIB=libvirglrenderer.1.dylib + +cd $BREW_VIRGL_DIR +mv $VIRGL_LIB ${VIRGL_LIB}.orig +ln -s $VIRGL_BUILD_DIR/src/$VIRGL_LIB +``` diff --git a/docs/backend/snapdragon/README.md b/docs/backend/snapdragon/README.md index 8e1f37b206..2c3f88e91a 100644 --- a/docs/backend/snapdragon/README.md +++ b/docs/backend/snapdragon/README.md @@ -35,7 +35,7 @@ Adapt below build commands accordingly. Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: ``` -[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json . +[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json . [d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon Preset CMake variables: diff --git a/docs/build-s390x.md b/docs/build-s390x.md index 67df4e2eac..4568d5010f 100644 --- a/docs/build-s390x.md +++ b/docs/build-s390x.md @@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl |------------|-------------|------|-------| | FP32 | ✅ | ✅ | ❓ | | FP16 | ✅ | ✅ | ❓ | -| BF16 | 🚫 | ✅ | ❓ | +| BF16 | ✅ | ✅ | ❓ | | Q4_0 | ✅ | ❓ | ❓ | | Q4_1 | ✅ | ❓ | ❓ | -| MXFP4 | 🚫 | ❓ | ❓ | +| MXFP4 | ✅ | ❓ | ❓ | | Q5_0 | ✅ | ❓ | ❓ | | Q5_1 | ✅ | ❓ | ❓ | | Q8_0 | ✅ | ❓ | ❓ | @@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl - 🚫 - acceleration unavailable, will still run using scalar implementation - ❓ - acceleration unknown, please contribute if you can test it yourself -Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025. +Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026. diff --git a/docs/speculative.md b/docs/speculative.md index 03afab5b41..29da332875 100644 --- a/docs/speculative.md +++ b/docs/speculative.md @@ -119,8 +119,6 @@ If a draft model is combined with a draftless decoding the draftless decoding ha of lookup n-gram (default: 12) --spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: 48) ---spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding - (default: 1) --spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1) ``` @@ -153,10 +151,6 @@ Sets the size M of the draft m-gram for n-gram map based speculative decoding. The m-gram size determines how many tokens to draft when a match is found. Larger values can provide more speedup but may reduce acceptance rate. -### `--spec-ngram-check-rate R` - -This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token). - ### `--spec-ngram-min-hits H` This option defines how often a key has to appear in the token history to be used as a draft (default is 1). @@ -175,7 +169,12 @@ draft acceptance rate = 0.70312 ( 90 accepted / 128 generated) statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms ``` -- `#calls`: number of calls of this implementations +``` +statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms +``` + + +- `#calls(b,g,a)`: number of calls of begin (new prompt), generation and accumulation of this implementations - `#gen drafts`: number of drafts generated by this implementation - `#acc drafts`: number of drafts accepted (partially) by the main model - `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index 215f1a9ee0..6f85ee4485 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -42,11 +42,15 @@ def load_model_and_tokenizer(model_path, device="auto"): config = config.text_config multimodal = True - print("Vocab size: ", config.vocab_size) - print("Hidden size: ", config.hidden_size) - print("Number of layers: ", config.num_hidden_layers) - print("BOS token id: ", config.bos_token_id) - print("EOS token id: ", config.eos_token_id) + def print_if_exists(label, obj, attr, default="N/A"): + val = getattr(obj, attr) if hasattr(obj, attr) else default + print(f"{label}", val) + + print_if_exists("Vocab size: ", config, "vocab_size") + print_if_exists("Hidden size: ", config, "hidden_size") + print_if_exists("Number of layers: ", config, "num_hidden_layers") + print_if_exists("BOS token id: ", config, "bos_token_id") + print_if_exists("EOS token id: ", config, "eos_token_id") unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME") if unreleased_model_name: diff --git a/examples/model-conversion/scripts/utils/tensor-info.py b/examples/model-conversion/scripts/utils/tensor-info.py index 12a3430b49..1bb9e0564c 100755 --- a/examples/model-conversion/scripts/utils/tensor-info.py +++ b/examples/model-conversion/scripts/utils/tensor-info.py @@ -78,7 +78,7 @@ def list_all_tensors(model_path: Path, unique: bool = False): print(tensor_name) -def print_tensor_info(model_path: Path, tensor_name: str): +def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None): tensor_file = find_tensor_file(model_path, tensor_name) if tensor_file is None: @@ -96,6 +96,12 @@ def print_tensor_info(model_path: Path, tensor_name: str): print(f"Tensor: {tensor_name}") print(f"File: {tensor_file}") print(f"Shape: {shape}") + if num_values is not None: + tensor = f.get_tensor(tensor_name) + print(f"Dtype: {tensor.dtype}") + flat = tensor.flatten() + n = min(num_values, flat.numel()) + print(f"Values: {flat[:n].tolist()}") else: print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") sys.exit(1) @@ -127,6 +133,15 @@ def main(): action="store_true", help="List unique tensor patterns in the model (layer numbers replaced with #)" ) + parser.add_argument( + "-n", "--num-values", + nargs="?", + const=10, + default=None, + type=int, + metavar="N", + help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)" + ) args = parser.parse_args() @@ -152,7 +167,7 @@ def main(): if args.tensor_name is None: print("Error: tensor_name is required when not using --list") sys.exit(1) - print_tensor_info(model_path, args.tensor_name) + print_tensor_info(model_path, args.tensor_name, args.num_values) if __name__ == "__main__": diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 71d1a7f0e3..4323afe57b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 5) +set(GGML_VERSION_PATCH 7) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index f759e2d588..77af0e7fb6 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -752,6 +752,7 @@ extern "C" { GGML_API bool ggml_is_transposed(const struct ggml_tensor * tensor); GGML_API bool ggml_is_permuted (const struct ggml_tensor * tensor); GGML_API bool ggml_is_empty (const struct ggml_tensor * tensor); + GGML_API bool ggml_is_view (const struct ggml_tensor * tensor); GGML_API bool ggml_is_scalar (const struct ggml_tensor * tensor); GGML_API bool ggml_is_vector (const struct ggml_tensor * tensor); GGML_API bool ggml_is_matrix (const struct ggml_tensor * tensor); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c index 41419b617b..7f414b2311 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c @@ -17,11 +17,6 @@ //#define AT_PRINTF(...) GGML_LOG_DEBUG(__VA_ARGS__) #define AT_PRINTF(...) - -static bool ggml_is_view(const struct ggml_tensor * t) { - return t->view_src != NULL; -} - // ops that return true for this function must not use restrict pointers for their backend implementations bool ggml_op_can_inplace(enum ggml_op op) { switch (op) { @@ -627,7 +622,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor GGML_ASSERT(buffer_id >= 0); struct hash_node * hn = ggml_gallocr_hash_get(galloc, node); - if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_is_view(node)) { + if (!ggml_gallocr_is_allocated(galloc, node) && !ggml_impl_is_view(node)) { hn->allocated = true; assert(hn->addr.offset == 0); @@ -658,7 +653,7 @@ static void ggml_gallocr_allocate_node(ggml_gallocr_t galloc, struct ggml_tensor struct hash_node * p_hn = ggml_gallocr_hash_get(galloc, parent); if (p_hn->n_children == 1 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); if (view_src_hn->n_views == 1 && view_src_hn->n_children == 0 && view_src->data == parent->data) { @@ -739,7 +734,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr // GGML_OP_NONE does not appear normally in the graph nodes, but is used by ggml-backend to add dependencies to // control when some tensors are allocated and freed. in this case, the dependencies are in `src`, but the node // itself is never used and should not be considered a dependency - if (ggml_is_view(node) && node->op != GGML_OP_NONE) { + if (ggml_impl_is_view(node) && node->op != GGML_OP_NONE) { struct ggml_tensor * view_src = node->view_src; ggml_gallocr_hash_get(galloc, view_src)->n_views += 1; } @@ -806,7 +801,7 @@ static void ggml_gallocr_alloc_graph_impl(ggml_gallocr_t galloc, struct ggml_cgr parent->name, p_hn->n_children, p_hn->n_views, p_hn->allocated); if (p_hn->n_children == 0 && p_hn->n_views == 0) { - if (ggml_is_view(parent)) { + if (ggml_impl_is_view(parent)) { struct ggml_tensor * view_src = parent->view_src; struct hash_node * view_src_hn = ggml_gallocr_hash_get(galloc, view_src); view_src_hn->n_views -= 1; diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 8a693f84af..311fa5fe36 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, int best_score = 0; fs::path best_path; + std::error_code ec; for (const auto & search_path : search_paths) { - if (std::error_code ec; !fs::exists(search_path, ec)) { + if (!fs::exists(search_path, ec)) { if (ec) { GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str()); } else { @@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - if (entry.is_regular_file()) { + if (entry.is_regular_file(ec)) { auto filename = entry.path().filename(); auto ext = entry.path().extension(); if (filename.native().find(file_prefix) == 0 && ext == file_extension) { diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 87ac05748e..fc7c3e3b72 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor } /** - * @brief Performs expert-specific matrix multiplication (MoE) with - * quantized precision using the CANN backend. + * @brief Performs quantized matrix multiplication for Mixture of Experts (MoE) + * models using the CANN backend. * - * This function executes a matrix multiplication operation tailored for - * Mixture of Experts (MoE) models, where the input tensor is multiplied - * with expert-specific quantized weight matrices. It leverages the CANN - * backend to perform efficient low-precision computations and stores the - * quantized result in the destination tensor `dst`. + * This function implements MUL_MAT_ID operation for quantized weight matrices + * (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on + * the provided expert indices, and computes matrix multiplication using CANN's + * WeightQuantBatchMatmulV2 operator. * - * Quantization techniques reduce memory footprint and improve performance - * by using lower-bit representations (e.g., int8) instead of floating-point. - * This function is designed to work with such formats and may incorporate - * optimizations like identity-based fast paths or routing masks for sparse - * expert selection. + * The function performs the following steps: + * 1. Converts input/output tensors to F16 format if necessary + * 2. Uses IndexSelect to extract expert-specific weights and scales based on indices + * 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2 + * 4. Converts output back to the target type if needed * - * @param ctx The context for executing CANN backend operations. - * @param dst The destination tensor where the quantized MoE multiplication result - * will be stored. + * Tensor shapes: + * - dst: [M, K, N, 1] - output tensor + * - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0) + * - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast) + * - ids: [K, N] - expert indices for routing * - * @note This function assumes quantized data types and is designed for - * MoE architectures with potential sparse expert routing. + * @param ctx The CANN backend context for operation execution. + * @param dst The destination tensor where the multiplication result will be stored. + * + * @note Only Q4_0 and Q8_0 quantization formats are supported. + * @note The function handles automatic type conversion to/from F16 as needed by the hardware. */ static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) { - // TODO: Use aclnnGroupedMatMul - //dst [M, K, N, 1] - ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] - ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 - ggml_tensor * ids = dst->src[2]; //ids [K, N] + // dst: [M, K, N, 1] + // src0: [D, M, A, 1] - quantized weights + // src1: [D, B, N, 1] - input activations, B = K or B = 1 + // ids: [K, N] - expert indices + ggml_tensor * src0 = dst->src[0]; + ggml_tensor * src1 = dst->src[1]; + ggml_tensor * ids = dst->src[2]; - GGML_TENSOR_BINARY_OP_LOCALS + GGML_ASSERT(src0->ne[3] == 1); + GGML_ASSERT(src1->ne[3] == 1); + GGML_ASSERT(dst->ne[3] == 1); + GGML_ASSERT(src1->ne[2] == ids->ne[1]); - // copy index from npu to cpu - int64_t n_as = ne02; // A - int64_t n_ids = ids->ne[0]; // K + const int64_t n_batches = ids->ne[1]; + const int64_t n_select_experts = ids->ne[0]; + const enum ggml_type type = src0->type; - std::vector ids_host(ggml_nbytes(ids)); - ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids), - ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream())); - ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); + const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32 + GGML_ASSERT(group_size == QK4_0); - char * src0_original = (char *) src0->data; - char * src1_original = (char *) src1->data; - char * dst_original = (char *) dst->data; + // Calculate element size for quantized weights + const float weight_elem_size = + (type == GGML_TYPE_Q4_0) ? 0.5f : + (type == GGML_TYPE_Q8_0) ? 1.0f : + (GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f); - ggml_tensor src0_row = *src0; - ggml_tensor src1_row = *src1; - ggml_tensor dst_row = *dst; + // Calculate scale offset in memory + const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size; + const size_t scale_elem_size = sizeof(uint16_t); + char * scale_data = (char *) src0->data + weight_size; - const enum ggml_type type = dst->src[0]->type; - float weight_elem_size; - if (type == GGML_TYPE_Q4_0) { - weight_elem_size = float(sizeof(uint8_t)) / 2; - } else if (type == GGML_TYPE_Q8_0) { - weight_elem_size = float(sizeof(uint8_t)); - } else { - GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 "); - } + // Allocate buffers for selected expert weights and scales + const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size; + ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size); + void * selected_weight_buffer = selected_weight_alloc.get(); - // src0_row [D, M, 1, 1] weight without permute - src0_row.ne[2] = 1; - src0_row.ne[3] = 1; - src0_row.nb[0] = weight_elem_size; - src0_row.nb[1] = weight_elem_size * ne00; - src0_row.nb[2] = weight_elem_size * ne00; - src0_row.nb[3] = weight_elem_size * ne00; - size_t weight_stride = ne00 * ne01 * weight_elem_size; - size_t weight_size = weight_stride * ne02 * ne03; + const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size; + ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size); + void * selected_scale_buffer = selected_scale_alloc.get(); - // scale [D, M, 1, 1] -> scale && permute - size_t scale_elem_size = sizeof(uint16_t); - size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size; + // Helper lambda to allocate and cast tensor to F16 if needed + constexpr size_t f16_elem_size = sizeof(uint16_t); + auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator, + bool need_cast = false) -> void * { + if (tensor->type == GGML_TYPE_F16) { + return tensor->data; + } - // src1_row [D, 1, 1, 1] -> input - src1_row.ne[1] = 1; - src1_row.ne[2] = 1; - src1_row.ne[3] = 1; - src1_row.nb[2] = nb11; - src1_row.nb[3] = nb11; + size_t total_size = f16_elem_size; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + total_size *= tensor->ne[i]; + } + void * buffer = allocator.alloc(total_size); - // dst_row [M, 1, 1, 1] -> out - dst_row.ne[1] = 1; - dst_row.ne[2] = 1; - dst_row.ne[3] = 1; - dst_row.nb[2] = nb1; - dst_row.nb[3] = nb1; + if (need_cast == false) { + return buffer; + } - //create weight for one row - ggml_cann_pool_alloc weight_allocator(ctx.pool()); - void * weight_buffer = weight_allocator.alloc(nb02); - for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) { - for (int64_t id = 0; id < n_ids; id++) { - // expert index - int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]); - GGML_ASSERT(i02 >= 0 && i02 < n_as); + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = tensor->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } - // If B = 1 (broadcast), always use 0; otherwise, use id. - int64_t i11 = (ne11 == 1 ? 0 : id); - int64_t i12 = iid1; + acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor); + acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16); - int64_t i1 = id; - int64_t i2 = i12; + return buffer; + }; - void * src0_tmp_ptr = src0_original + i02 * weight_stride; - void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride; - void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12; - void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2; + // Prepare input and output buffers + ggml_cann_pool_alloc input_alloc(ctx.pool()); + void * input_buffer = prepare_f16_buffer(src1, input_alloc, true); - // mem cpy - ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); - void * scale_buffer = (char *) weight_buffer + weight_stride; - ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride, - ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream())); + ggml_cann_pool_alloc output_alloc(ctx.pool()); + void * output_buffer = prepare_f16_buffer(dst, output_alloc, false); - src0_row.data = weight_buffer; - src1_row.data = src1_tmp_ptr; - dst_row.data = dst_tmp_ptr; - dst_row.src[0] = &src0_row; - dst_row.src[1] = &src1_row; + // Process each batch + for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) { + // Create index tensor for current batch + const size_t index_offset = batch_idx * ids->nb[1]; + acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset); - ggml_cann_mul_mat(ctx, &dst_row); + // Select quantized weights using expert indices + // Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte + const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0]; + const int64_t weight_m = src0->ne[1]; + const int64_t weight_n_experts = src0->ne[2]; + + int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts }; + size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr all_weights = + ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3); + + int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts }; + size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), + weight_d * weight_m * sizeof(int8_t) }; + + acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t), + selected_weight_ne, selected_weight_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get()); + + // Select scales using the same expert indices + const int64_t scale_d = src0->ne[0] / group_size; + int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts }; + size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr all_scales = + ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3); + + int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts }; + size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, + scale_d * weight_m * scale_elem_size }; + + acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, + selected_scale_ne, selected_scale_nb, 3); + + GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get()); + + // Process each expert for current batch + // IndexSelect output layout: [D, M, K] in contiguous format + // WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride + for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) { + // Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input + const size_t input_offset = + (batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size; + const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size; + + // Create weight view for current expert: [D, M, K] -> [M, D] + int64_t weight_view_ne[2] = { weight_m, src0->ne[0] }; + float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size }; + const size_t weight_view_offset = expert_idx * selected_weight_nb[2]; + + acl_tensor_ptr weight_view = + ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size, + weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset); + + // Create scale view for current expert: [D, M, K] -> [M, D] + int64_t scale_view_ne[2] = { weight_m, scale_d }; + size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] }; + const size_t scale_view_offset = expert_idx * selected_scale_nb[2]; + + acl_tensor_ptr scale_view = + ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne, + scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset); + + // Create input activation tensor [D, 1] + int64_t input_ne[2] = { src1->ne[0], 1 }; + size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size }; + + acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne, + input_nb, 2, ACL_FORMAT_ND, input_offset); + + // Create output tensor [M, 1] + int64_t output_ne[2] = { dst->ne[0], 1 }; + size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size }; + + acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne, + output_nb, 2, ACL_FORMAT_ND, output_offset); + + // Perform quantized matrix multiplication + GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(), + scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size, + output_tensor.get()); } } - return; + + // Cast output back to original type if we used a temporary F16 buffer + if (dst->type != GGML_TYPE_F16) { + int64_t ne[GGML_MAX_DIMS]; + size_t nb[GGML_MAX_DIMS] = { f16_elem_size }; + for (int i = 0; i < GGML_MAX_DIMS; i++) { + ne[i] = dst->ne[i]; + if (i > 0) { + nb[i] = nb[i - 1] * ne[i - 1]; + } + } + + acl_tensor_ptr f16_output = + ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS); + acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst); + + aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type)); + } } void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) { diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index 6b2dbdd359..3f3de9f0bc 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context { ~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); } }; +// cann buffer type /** - * @brief Check if a buffer is a CANN buffer. - * - * This function checks if a given buffer is a CANN buffer by comparing its - * `get_name` function pointer to `ggml_backend_cann_buffer_get_name`. - * - * @param buffer The buffer to check. - * @return true if the buffer is a CANN buffer, false otherwise. + * @brief Structure representing context information for a specific backend + * buffer type. */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft); +struct ggml_backend_cann_buffer_type_context { + int32_t device; /**< Device identifier associated with the buffer context. */ + std::string name; /**< Name associated with the buffer context. */ +}; -static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) { - return ggml_backend_buft_is_cann(buffer->buft); +/** + * @brief Retrieves the name associated with a CANN buffer type. + * + * This function returns the descriptive name associated with the specified + * CANN buffer type context. + * + * @param buft Pointer to the buffer type context. + * @return Const pointer to the C-style string containing the name. + */ +static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; + + return buft_ctx->name.c_str(); +} + +/** + * @brief Checks if the backend buffer type is associated with the CANN backend. + * + * This function checks whether the provided backend buffer type is associated + * with the CANN backend based on the comparison of its name retrieval function + * pointer. + * + * @param buft Pointer to the backend buffer type to check. + * @return bool Returns true if the buffer type is associated with the CANN + * backend, otherwise false. + */ +static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { + return buft->iface.get_name == ggml_backend_cann_buffer_type_name; } /** @@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer, static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { - if (ggml_backend_buffer_is_cann(src->buffer)) { + if (ggml_backend_buft_is_cann(src->buffer->buft)) { ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context; ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context; @@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = { /* .reset = */ NULL, }; -// cann buffer type -/** - * @brief Structure representing context information for a specific backend - * buffer type. - */ -struct ggml_backend_cann_buffer_type_context { - int32_t device; /**< Device identifier associated with the buffer context. */ - std::string name; /**< Name associated with the buffer context. */ -}; - -/** - * @brief Retrieves the name associated with a CANN buffer type. - * - * This function returns the descriptive name associated with the specified - * CANN buffer type context. - * - * @param buft Pointer to the buffer type context. - * @return Const pointer to the C-style string containing the name. - */ -static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) { - ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context; - - return buft_ctx->name.c_str(); -} - /** * @brief Allocates a new CANN buffer of the specified type and size. * @@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src, GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src)); - if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) { + if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) { return false; } @@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten GGML_UNUSED(dev); } -/** - * @brief Checks if the backend buffer type is associated with the CANN backend. - * - * This function checks whether the provided backend buffer type is associated - * with the CANN backend based on the comparison of its name retrieval function - * pointer. - * - * @param buft Pointer to the backend buffer type to check. - * @return bool Returns true if the buffer type is associated with the CANN - * backend, otherwise false. - */ -static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) { - return buft->iface.get_name == ggml_backend_cann_buffer_type_name; -} - /** * @brief Records an event on the CANN backend stream. * diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7622d0bf49..3dc948e4d8 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -9,6 +9,11 @@ function(ggml_add_cpu_backend_features cpu_name arch) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE ${ARGN}) target_compile_definitions(${GGML_CPU_FEATS_NAME} PRIVATE GGML_BACKEND_DL GGML_BACKEND_BUILD GGML_BACKEND_SHARED) set_target_properties(${GGML_CPU_FEATS_NAME} PROPERTIES POSITION_INDEPENDENT_CODE ON) + # Disable LTO for the feature detection code to prevent cross-module optimization + # from inlining architecture-specific instructions into the score function. + # Without this, LTO can cause SIGILL when loading backends on older CPUs + # (e.g., loading power10 backend on power9 crashes before feature check runs). + target_compile_options(${GGML_CPU_FEATS_NAME} PRIVATE -fno-lto) target_link_libraries(${cpu_name} PRIVATE ${GGML_CPU_FEATS_NAME}) endfunction() @@ -569,27 +574,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name) cmake_policy(SET CMP0135 NEW) endif() + # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ + # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 FetchContent_Declare(KleidiAI_Download URL ${KLEIDIAI_DOWNLOAD_URL} DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) - FetchContent_MakeAvailable(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC POPULATED KLEIDIAI_POPULATED) if (NOT KLEIDIAI_POPULATED) - message(FATAL_ERROR "KleidiAI source downloaded failed.") + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) - # Remove kleidiai target after fetching it - if (TARGET kleidiai) - set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) - endif() - list(APPEND GGML_CPU_SOURCES ggml-cpu/kleidiai/kleidiai.cpp ggml-cpu/kleidiai/kernels.cpp diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 427c1146e4..c6eb75b230 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -43,6 +43,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -55,7 +56,8 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K -# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -76,6 +78,7 @@ #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 @@ -84,6 +87,7 @@ #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 @@ -107,6 +111,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -119,6 +124,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -143,6 +149,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -155,6 +162,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -186,6 +194,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -197,6 +206,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -227,6 +237,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -239,6 +250,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 @@ -271,6 +283,7 @@ #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K +#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 @@ -283,6 +296,7 @@ #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K +#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 99bb70274c..3a3b32efb2 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n, ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[2]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d); + float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d); + + int32x4_t acc[col_groups]; + for (int i = 0; i < col_groups; i++) { + acc[i] = vdupq_n_s32(0); + } + + // Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block) + // Reused for bias and dequantization later + int16_t q6_scales[16 * 8]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + // Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift + int32x4_t bias_lo = vdupq_n_s32(0); + int32x4_t bias_hi = vdupq_n_s32(0); + + // Load bsums in chunks of 4 to process with vectorized operations + for (int i = 0; i < 16; i += 4) { + int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i); + int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8); + int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4); + int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8); + int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4); + int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8); + int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4); + int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8); + int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4); + + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2); + bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3); + bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3); + } + bias_lo = vshlq_n_s32(bias_lo, 5); + bias_hi = vshlq_n_s32(bias_hi, 5); + + // Process two 128-value halves per superblock + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + // A subblock (sb) is a set of weights that share the scale + // Since q6_K scales are per 16 elements + // num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves) + for (int sb = 0; sb < QK_K / 64; sb++) { + const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16; + const int8_t * q8_base_h = q8_base_l + 64; + + // Load and duplicate q8 values (each register covers four interleaved columns of q6) + int8x16_t q8_l[4]; + int8x16_t q8_h[4]; + for (int i = 0; i < 4; i++) { + q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4)); + q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4)); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes + + // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); + + // Adjust qh for subblocks 2 and 3 (shift right by 2) + if (sb > 1) { + q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2); + q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2); + q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2); + q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2); + q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2); + q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2); + q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2); + q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2); + } + + const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3], + q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] }; + const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3], + q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] }; + + // Process column groups (0-3, 4-7) + for (int g = 0; g < col_groups; g++) { + int32x4_t sb_acc_l = vdupq_n_s32(0); + int32x4_t sb_acc_h = vdupq_n_s32(0); + + for (int chunk = 0; chunk < 4; chunk++) { + const int idx = chunk * 2 + g; + + const uint8x16_t q6_qs_l = q6_ql[idx]; + const uint8x16_t q6_qs_h = q6_qh[idx]; + + // Extract high 2 bits for upper nibble reconstruction + const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi); + + // q6 = (low4 | high2<<4), without -32 bias (handled via bsums) + const int8x16_t q6_l = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4)); + const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh)); + + sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]); + sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]); + } + + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4)); + const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4)); + + acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l); + acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h); + } + } + } // for half + + // Bias correction + acc[0] = vsubq_s32(acc[0], bias_lo); + acc[1] = vsubq_s32(acc[1], bias_hi); + + // Apply superblock scale (no mins for q6_K) + // acc[g] has [c0, c1, c2, c3] + float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0); + float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1); + + acc_f32[0] = vaddq_f32(acc_f32[0], w_0123); + acc_f32[1] = vaddq_f32(acc_f32[1], w_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n, q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8)); } - // TODO: Test other qh repack patterns to reduce loads const int ql_off_base = sb * QK_K / 2; const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes // Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1) - ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base); - ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64); - ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base); - ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64); + uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base); + uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64); + uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base); + uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64); // Adjust qh for subblocks 2 and 3 (shift right by 2) if (sb > 1) { @@ -3038,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + constexpr int q8_k_blocklen = 4; + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; + uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + svbool_t pg = svptrue_pat_b32(SV_VL8); + svuint32_t idx = svld1(pg, idx_arr); + + static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; + svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + acc_f32_01 = svdup_n_f32(0); + acc_f32_23 = svdup_n_f32(0); + acc_f32_45 = svdup_n_f32(0); + acc_f32_67 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + + svint32_t acc_00 = svdup_n_s32(0); + svint32_t acc_11 = svdup_n_s32(0); + svint32_t acc_22 = svdup_n_s32(0); + svint32_t acc_33 = svdup_n_s32(0); + svint32_t acc_44 = svdup_n_s32(0); + svint32_t acc_55 = svdup_n_s32(0); + svint32_t acc_66 = svdup_n_s32(0); + svint32_t acc_77 = svdup_n_s32(0); + + svint32_t bias_acc_00 = svdup_n_s32(0); + svint32_t bias_acc_22 = svdup_n_s32(0); + svint32_t bias_acc_44 = svdup_n_s32(0); + svint32_t bias_acc_66 = svdup_n_s32(0); + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; + svint32_t q4sb_mins_0, q4sb_mins_1; + { + // 2-superblock I am working on + const int offset = sb * 24 + 0 * 12; + const uint8_t * scales_in = &q4_ptr[b].scales[offset]; + + const int offset1 = sb * 24 + 12; + const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + uint32_t sm1[3]; + memcpy(sm1, scales_in1, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + const uint32_t mins_0_3_1 = sm1[1] & kmask1; + const uint32_t mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); + svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); + + /* reinterpret u32 → u8 */ + svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); + svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); + + /* widen u8 → u16->u32 (lower half only) */ + svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); + svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); + + q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); + q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); + + uint32_t scales_u32_0 = sm[0] & kmask1; + uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + uint32_t scales_u32_2 = sm1[0] & kmask1; + uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); + + svuint32_t S01 = svdup_n_u32(scales_u32_0); + svuint32_t S23 = svdup_n_u32(scales_u32_1); + svuint32_t R01 = svdup_n_u32(scales_u32_2); + svuint32_t R23 = svdup_n_u32(scales_u32_3); + + svint8_t S01_b = svreinterpret_s8_u32(S01); + svint8_t S23_b = svreinterpret_s8_u32(S23); + svint8_t R01_b = svreinterpret_s8_u32(R01); + svint8_t R23_b = svreinterpret_s8_u32(R23); + + svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); + svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); + svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); + svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); + + block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); + block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); + block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); + block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); + } + + const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; + + // Load 32-byte per row pair, 1 subblock each time + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); + + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); + + if(cp == 0) { + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); + } + if(cp == 1) { + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); + } + if(cp == 2) { + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); + } + if(cp == 3) { + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); + } + } + + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); + + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); + + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); + + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); + } // for sb + + + acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); + acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); + acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); + acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); + acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); + acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); + acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); + acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); + + svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); + svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); + + svint32_t reorder_acc_45 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); + svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); + + svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); + + svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); + acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); + acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); + acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); + acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. + // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + + if (i == 0 && j == 0) { + // acc_f32_0 → lower half of acc_f32_01 + svst1_f32(pg4, s + offset, acc_f32_01); + } else if (i == 0 && j == 1) { + // acc_f32_1 → upper half of acc_f32_01 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); + } else if (i == 1 && j == 0) { + // acc_f32_2 + svst1_f32(pg4, s + offset, acc_f32_23); + } else if (i == 1 && j == 1) { + // acc_f32_3 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); + } else if (i == 2 && j == 0) { + // acc_f32_4 + svst1_f32(pg4, s + offset, acc_f32_45); + } else if (i == 2 && j == 1) { + // acc_f32_5 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); + } else if (i == 3 && j == 0) { + // acc_f32_6 + svst1_f32(pg4, s + offset, acc_f32_67); + } else if (i == 3 && j == 1) { + // acc_f32_7 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); + } + } + } + } // for x + } // for y + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f); @@ -3474,6 +3972,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n, ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q6_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int col_groups = ncols_interleaved / 4; + constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mask_lo = vdupq_n_u8(0x03); + const uint8x16_t mask_hi = vdupq_n_u8(0x30); + const int8x16_t m32s = vdupq_n_s8(32); + + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); + float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0); + sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1); + sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2); + sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3); + + int32x4_t acc_s32[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_s32[i] = vdupq_n_s32(0); + } + + int16_t q6_scales[8 * 16]; + for (int i = 0; i < 16; i++) { + int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8)); + vst1q_s16(q6_scales + i * 8, scales); + } + + for (int half = 0; half < 2; half++) { + const uint8_t * ql_base = q6_ptr[b].ql + half * 512; + const uint8_t * qh_base = q6_ptr[b].qh + half * 256; + + for (int sb = 0; sb < QK_K / 64; sb++) { + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + + const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64; + const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64; + + // 4 rows * 16 elements per scale + // 4 reads of 16 bytes each + constexpr int reads_per_sb = 4; + int8x16_t q8_l[reads_per_sb]; + int8x16_t q8_h[reads_per_sb]; + for (int k = 0; k < reads_per_sb; k++) { + q8_l[k] = vld1q_s8(q8_base_l + 16 * k); + q8_h[k] = vld1q_s8(q8_base_h + 16 * k); + } + + const int ql_off_base = sb * QK_K / 2; + const int qh_off_base = ql_off_base & 255; + + uint8x16_t q6_ql_0123[reads_per_sb]; + uint8x16_t q6_ql_4567[reads_per_sb]; + uint8x16_t q6_qh_0123[reads_per_sb]; + uint8x16_t q6_qh_4567[reads_per_sb]; + + for (int k = 0; k < reads_per_sb; k++) { + q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32); + q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16); + q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32); + q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16); + } + + if (sb > 1) { + for (int k = 0; k < reads_per_sb; k++) { + q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2); + q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2); + } + } + + for (int k = 0; k < reads_per_sb; k++) { + // q = (ql | qh) - 32 + const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo); + const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi); + const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo); + const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi); + + const int8x16_t q6_0123_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s); + const int8x16_t q6_0123_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123 + + const int8x16_t q6_4567_lo = vsubq_s8( + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s); + const int8x16_t q6_4567_hi = vsubq_s8( + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567 + } + + // Scale and bias + const int scale_idx_l = half * 8 + sb; + const int scale_idx_h = half * 8 + sb + 4; + + for (int g = 0; g < col_groups; g++) { + const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4); + const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4); + const int32x4_t scale_vec_l = vmovl_s16(scales_l16); + const int32x4_t scale_vec_h = vmovl_s16(scales_h16); + const int acc_offset = g * q8_k_blocklen; + + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx = row * 2 + g; + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l); + acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h); + } + } + } + } + + // Finally we apply the superblock scales + for (int row = 0; row < q8_k_blocklen; row++) { + const int idx0 = 2 * row; + const int idx1 = 2 * row + 1; + const int32x4_t acc_0123 = acc_s32[idx0]; + const int32x4_t acc_4567 = acc_s32[idx1]; + + acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]); + acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/binary-ops.cpp b/ggml/src/ggml-cpu/binary-ops.cpp index 14f5b43ae0..75e3829001 100644 --- a/ggml/src/ggml-cpu/binary-ops.cpp +++ b/ggml/src/ggml-cpu/binary-ops.cpp @@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds GGML_ASSERT(nb00 == sizeof(src0_t)); const auto [ir0, ir1] = get_thread_range(params, src0); - const bool is_src1_contiguous = (nb10 == sizeof(src1_t)); - - if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous - GGML_ASSERT(ggml_are_same_shape(src0, src1)); - } + const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1); #ifdef GGML_USE_ACCELERATE vDSP_fn_t vDSP_op = nullptr; @@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11); - if (is_src1_contiguous) { + if (is_src1_contiguous_rows) { // src1 is broadcastable across src0 and dst in i1, i2, i3 const int64_t nr0 = ne00 / ne10; diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 1057b5bb15..abbadc359c 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,8 +6,8 @@ #include "ggml-impl.h" #include "simd-mappings.h" -#define GGML_FA_TILE_Q 32 -#define GGML_FA_TILE_KV 16 +#define GGML_FA_TILE_Q 64 +#define GGML_FA_TILE_KV 64 #ifdef __cplusplus diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b003fe13fd..64eb01a4e1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan( const int64_t DV = node->src[2]->ne[0]; // Tiled flash attention scratch (tile sizes defined in common.h) - // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding - size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks; // Decode path: n_kv_chunks = n_tasks (one chunk per thread) // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ @@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.use_ref =*/ cplan->use_ref, }; - GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif ggml_barrier(state->threadpool); diff --git a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h b/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h deleted file mode 100644 index a707868728..0000000000 --- a/ggml/src/ggml-cpu/llamafile/sgemm-ppc.h +++ /dev/null @@ -1,333 +0,0 @@ -#pragma once - -typedef vector unsigned char vec_t; -typedef __vector_quad acc_t; - -template -class tinyBLAS_Q0_PPC { - public: - tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth); - - void matmul(int64_t m, int64_t n); - void matmul_tiled_q0(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { - vec_t A_pack[mc*kc*2]; - vec_t B_pack[nc*kc*2]; - int comparray[mc*kc]; - constexpr bool is_Ablock_q4 = std::is_same_v; - int64_t ytiles = m / mc; - int64_t xtiles = n / nc; - int64_t tiles = xtiles * ytiles; - int64_t duty = (tiles + nth - 1) / nth; - int64_t start = duty * ith; - int64_t end = start + duty; - if (end > tiles) { - end = tiles; - } - for (int64_t job = start; job < end; ++job) { - int64_t ii = (job / xtiles) * mc; - int64_t jj = (job % xtiles) * nc; - for (int64_t kk = 0; kk < k; kk += kc) { - if constexpr(is_Ablock_q4) { - packNormalInt4_large(A + ii*lda + kk, lda, mc, 4, (int8_t*)A_pack, comparray); - } else { - packNormal_large(A + ii*lda + kk, lda, mc, 8, (int8_t*)A_pack, false, comparray); - } - packNormal_large(B + jj*ldb + kk, ldb, nc, 8, (uint8_t*)B_pack, true); - KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack, comparray); - } - } - } - - private: - inline void save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - *((float*)(C+ii+((jj+J)*ldc)+I)) = *((float*)&fin_res[idx+I]+J); - } - } - } - - inline void add_save_res(int ii, int jj, int idx, vector float* fin_res, int RM=4, int RN=4) { - for (int I = 0; I < RM; I++) { - for (int J = 0; J < RN; J++) { - float * c_ptr = (float *)(C+ii+((jj+J)*ldc)+I); - *c_ptr += *((float*)&fin_res[idx+I]+J); - } - } - } - - template - inline void compute(acc_t* ACC, int c_idx, int s_idx, ArrayType& comparray, vector float* vs, vector float* fin_res) { - vector signed int vec_C[4]; - vector float CA[4] = {0}; - vector float res[4] = {0}; - __builtin_mma_disassemble_acc(vec_C, ACC); - for (int i = 0; i < 4; i++) { - CA[i] = vec_splats((float)(((double)comparray[c_idx+i]) * -128.0)); - res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); - fin_res[s_idx+i] = vec_madd(res[i], vs[s_idx+i], fin_res[s_idx+i]); - } - } - - inline void process_q4_elements(vector signed char (&c)[2], int* ca) { - const vector signed char lowMask = vec_splats((signed char)0xF); - const vector unsigned char v4 = vec_splats((unsigned char)0x4); - const vector signed char v8 = vec_splats((signed char)0x8); - vector signed int vsum = {0}; - vector signed int vsum2 = {0}; - c[0] = vec_and(c[1], lowMask); - c[1] = vec_sr(c[1], v4); - c[0] = vec_sub(c[0], v8); - c[1] = vec_sub(c[1], v8); - vsum = vec_sum4s(c[0], vsum); - vsum2 = vec_sum4s(c[1], vsum2); - vsum = vec_add(vsum, vsum2); - *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - inline void vector_permute_store(V2 &s1, V2 &s2, V2 &s3, V2 &s4, V1 *vecOffset, bool flip) { - vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; - vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; - vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; - vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; - V2 t1, t2, t3, t4, t5, t6, t7, t8; - vector unsigned char xor_vector; - uint8_t flip_vec = 0x80; - xor_vector = vec_splats(flip_vec); - t1 = vec_perm(s1, s2, swiz1); - t2 = vec_perm(s1, s2, swiz2); - t3 = vec_perm(s3, s4, swiz1); - t4 = vec_perm(s3, s4, swiz2); - t5 = vec_perm(t1, t3, swiz3); - t6 = vec_perm(t1, t3, swiz4); - t7 = vec_perm(t2, t4, swiz3); - t8 = vec_perm(t2, t4, swiz4); - if (flip == true) { - t5 = vec_xor(t5, xor_vector); - t6 = vec_xor(t6, xor_vector); - t7 = vec_xor(t7, xor_vector); - t8 = vec_xor(t8, xor_vector); - } - vec_xst(t5, 0, vecOffset); - vec_xst(t6, 0, vecOffset+16); - vec_xst(t7, 0, vecOffset+32); - vec_xst(t8, 0, vecOffset+48); - } - - template - inline void kernel(int64_t ii, int64_t jj) { - if constexpr(RM == 4 && RN == 8) { - KERNEL_4x8(ii,jj); - } else if constexpr(RM == 8 && RN == 4) { - KERNEL_8x4(ii,jj); - } else if constexpr(RM == 8 && RN == 8) { - KERNEL_8x8(ii,jj); - } else { - assert(false && "RN/RM values not supported"); - } - } - template - void packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray); - template - void packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip); - void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n); - void KERNEL_4x8(int64_t ii, int64_t jj); - void KERNEL_8x4(int64_t ii, int64_t jj); - void KERNEL_8x8(int64_t ii, int64_t jj); - void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN); - template - void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n); - - void compute_scale(int64_t ii, int64_t jj, int blk, vector float* vs){ - for (int I = 0; I<8; I++) { - float a_scale = unhalf((A+((ii+I)*lda)+blk)->d); - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (a_scale * unhalf((B+((jj+J)*ldb)+blk)->d)); - *((float*)&vs[I+8]+J) = (a_scale * unhalf((B+((jj+J+4)*ldb)+blk)->d)); - } - } - } - - inline void process_q8_elements(const int8_t *qs, int *ca) { - vector signed char c1 = vec_xl(0, qs); - vector signed char c2 = vec_xl(16, qs); - vector signed int vsum1 = {0}; - vector signed int vsum2 = {0}; - vsum1 = vec_sum4s(c1, vsum1); - vsum2 = vec_sum4s(c2, vsum2); - vector signed int vsum = vec_add(vsum1, vsum2); - *ca = vsum[0] + vsum[1] + vsum[2] + vsum[3]; - } - - template - void packNormal_large(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip, int* comparray=nullptr) { - int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; - __vector_pair arr[8]; - VB c[8][2] = {0}; - VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - j = (rows >> 3); - int index = 0; - if (j > 0) { - do { - for (int it = 0; it < 8; it++) - aoffsets[it] = aoffset + it*lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)(aoffsets[it]+blk)->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); - c1[it] = c[it][0]; - c2[it] = c[it][1]; - if (comparray){ - process_q8_elements((aoffsets[it]+ blk)->qs, &comparray[index + 8*blk + it]); - } - } - vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); - vecOffset += 256; - } - j--; - index += 8*kc; - } while(j > 0); - } - - } - - void packNormalInt4_large(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, int*comparray) { - int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; - vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; - vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); - vecOffset = vec; - int index = 0; - j = (rows >> 3); - if (j > 0) { - do { - aoffset1 = aoffset; - aoffset2 = aoffset1 + lda; - aoffset3 = aoffset2 + lda; - aoffset4 = aoffset3 + lda; - aoffset5 = aoffset4 + lda; - aoffset6 = aoffset5 + lda; - aoffset7 = aoffset6 + lda; - aoffset8 = aoffset7 + lda; - aoffset += 8 * lda; - for (int blk = 0; blk < kc; blk++) { - c1[1] = reinterpret_cast(vec_xl(0, (aoffset1+blk)->qs)); - c2[1] = reinterpret_cast(vec_xl(0, (aoffset2+blk)->qs)); - c3[1] = reinterpret_cast(vec_xl(0, (aoffset3+blk)->qs)); - c4[1] = reinterpret_cast(vec_xl(0, (aoffset4+blk)->qs)); - c5[1] = reinterpret_cast(vec_xl(0, (aoffset5+blk)->qs)); - c6[1] = reinterpret_cast(vec_xl(0, (aoffset6+blk)->qs)); - c7[1] = reinterpret_cast(vec_xl(0, (aoffset7+blk)->qs)); - c8[1] = reinterpret_cast(vec_xl(0, (aoffset8+blk)->qs)); - - process_q4_elements(c1, &comparray[index + 8*blk+0]); - process_q4_elements(c2, &comparray[index + 8*blk+1]); - process_q4_elements(c3, &comparray[index + 8*blk+2]); - process_q4_elements(c4, &comparray[index + 8*blk+3]); - process_q4_elements(c5, &comparray[index + 8*blk+4]); - process_q4_elements(c6, &comparray[index + 8*blk+5]); - process_q4_elements(c7, &comparray[index + 8*blk+6]); - process_q4_elements(c8, &comparray[index + 8*blk+7]); - vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); - vecOffset += 256; - } - j--; - index += 8*kc; - } while (j > 0); - } - } - - void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t *vec_A, vec_t *vec_B, int *comparray) { - acc_t acc[8]; - for (int i = 0; i < mc ; i += 8) { - for (int j = 0; j < nc; j += 8) { - vector float fin_res[16] = {0}; - vector float vs[16] = {0}; - for (int64_t kk = 0; kk < kc; kk+=2) { - for (int x = 0; x < 8; x++) { - __builtin_mma_xxsetaccz(&acc[x]); - } - int A_block_idx = (i/8)*(16*kc) + kk*16; - int B_block_idx = (j/8)*(16*kc)+ kk*16; - vec_t *A_block = &vec_A[A_block_idx]; - vec_t *B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[0], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[1], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[2], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[3], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk, vs); - int c_index = (i/8)*(8*kc)+ kk*8; - int* c_block = &comparray[c_index]; - compute(&acc[0], 0, 0, c_block, vs, fin_res); - compute(&acc[1], 4, 4, c_block, vs, fin_res); - compute(&acc[2], 0, 8, c_block, vs, fin_res); - compute(&acc[3], 4, 12, c_block, vs, fin_res); - - A_block_idx = (i/8)*(16*kc) + (kk+1)*16; - B_block_idx = (j/8)*(16*kc)+ (kk+1)*16; - A_block = &vec_A[A_block_idx]; - B_block = &vec_B[B_block_idx]; - for (int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc[4], A_block[x], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[5], A_block[x + 8], B_block[x]); - __builtin_mma_xvi8ger4pp(&acc[6], A_block[x], B_block[x+8]); - __builtin_mma_xvi8ger4pp(&acc[7], A_block[x+8], B_block[x+8]); - } - compute_scale(ii+i, jj+j, l+kk+1, vs); - c_index = (i/8)*(8*kc)+ (kk+1)*8; - c_block = &comparray[c_index]; - compute(&acc[4], 0, 0, c_block, vs, fin_res); - compute(&acc[5], 4, 4, c_block, vs, fin_res); - compute(&acc[6], 0, 8, c_block, vs, fin_res); - compute(&acc[7], 4, 12, c_block, vs, fin_res); - - } - if (l == 0) { - save_res(ii+i, jj+j, 0, fin_res); - save_res(ii+i+4, jj+j, 4, fin_res); - save_res(ii+i, jj+j+4, 8, fin_res); - save_res(ii+i+4, jj+j+4, 12, fin_res); - } else { - add_save_res(ii+i, jj+j, 0, fin_res); - add_save_res(ii+i+4, jj+j, 4, fin_res); - add_save_res(ii+i, jj+j+4, 8, fin_res); - add_save_res(ii+i+4, jj+j+4, 12, fin_res); - } - } - } - } - - const TA *const A; - const block_q8_0 *const B; - float *C; - const int64_t k; - int64_t kc; - const int64_t lda; - const int64_t ldb; - const int64_t ldc; - const int ith; - const int nth; -}; diff --git a/ggml/src/ggml-cpu/llamafile/sgemm.cpp b/ggml/src/ggml-cpu/llamafile/sgemm.cpp index 8f980c16b9..da412fd009 100644 --- a/ggml/src/ggml-cpu/llamafile/sgemm.cpp +++ b/ggml/src/ggml-cpu/llamafile/sgemm.cpp @@ -121,7 +121,8 @@ inline float32x4_t mul(float32x4_t x, float32x4_t y) { return vec_mul(x, y); } #endif #if defined(__MMA__) -#include "sgemm-ppc.h" +typedef vector unsigned char vec_t; +typedef __vector_quad acc_t; #endif //////////////////////////////////////////////////////////////////////////////////////////////////// // VECTORIZED FUSED MULTIPLY ADD @@ -2153,7 +2154,7 @@ class tinyBLAS_HP16_PPC { packNormal((B+(jj*ldb)+l), ldb, 8, 4, (uint8_t*)vec_B); for (int x = 0; x < 4; x++) { mma_instr::outer_product(&acc_0, vec_A[x], vec_B[x]); - mma_instr::outer_product(&acc_1, vec_A[x], vec_B[x+4]); + mma_instr::outer_product(&acc_1, vec_A[x+4], vec_B[x]); } } SAVE_ACC(&acc_0, ii, jj); @@ -2301,43 +2302,299 @@ class tinyBLAS_HP16_PPC { const int nth; }; - template - tinyBLAS_Q0_PPC::tinyBLAS_Q0_PPC(int64_t k, - const TA *A, int64_t lda, - const block_q8_0 *B, int64_t ldb, - float *C, int64_t ldc, - int ith, int nth) +template +class tinyBLAS_Q0_PPC { + public: + tinyBLAS_Q0_PPC(int64_t k, + const TA * A, int64_t lda, + const block_q8_0 * B, int64_t ldb, + float * C, int64_t ldc, + int ith, int nth) : A(A), B(B), C(C), k(k), lda(lda), ldb(ldb), ldc(ldc), ith(ith), nth(nth) { - kc = 64; } - template - void tinyBLAS_Q0_PPC::matmul(int64_t m, int64_t n) { - int mc = 64; int nc = 64; - if (n % 8 == 0 && n < nc) { - nc = n; - mc = 32 ; - kc = 32; + void matmul(int64_t m, int64_t n) { + const int64_t mc = 64; + const int64_t kc = 64; + int64_t nc = 64; + int64_t n_aligned = 0; + if (n % 64 == 0) { + n_aligned = n; + } else if (n == 4) { + n_aligned = 4; + } else if (n < 64) { + n_aligned = (n / 8) * 8; + } else { + n_aligned = (n / 64) * 64; } - const bool is_aligned = ((m & (mc - 1)) == 0) & ((n & (nc - 1)) == 0) & ((k & (kc - 1)) == 0); - if (is_aligned) { - this->matmul_tiled_q0(m, n, mc, nc, kc); + + if (n_aligned > 0) { + if (n_aligned % 64 == 0) nc = 64; + else if (n_aligned == n) nc = n; + else if (n_aligned % 32 == 0) nc = 32; + else if (n_aligned % 24 == 0) nc = 24; + else if (n_aligned % 16 == 0) nc = 16; + else nc = 8; + } + bool can_use_tiled = n_aligned > 0 && (m % mc == 0) && (k % kc == 0); + if (can_use_tiled) { + matmul_tiled(m, n_aligned, mc, nc, kc); + if (n > n_aligned) { + mnpack(0, m, n_aligned, n); + } } else { mnpack(0, m, 0, n); } } - template - template - void tinyBLAS_Q0_PPC::packNormalInt4(const TA* a, int64_t lda, int rows, int cols, int8_t* vec, std::array& comparray) { + private: + inline void save_res(int ii, int jj, int idx, vector float * fin_res, int RM = 4, int RN = 4) { + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&fin_res[idx + I] + J); + } + } + } + + inline void save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + *((float *)(C + ii + ((jj + J) * ldc) + I)) = *((float *)&vec_C[I] + J); + } + } + } + + inline void add_save_acc(acc_t * ACC, int64_t ii, int64_t jj) { + vec_t vec_C[4]; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int I = 0; I < 4; I++) { + for (int J = 0; J < 4; J++) { + float * c_ptr = (float *)(C + ii+ ((jj + J) * ldc) + I); + *c_ptr += *((float *)&vec_C[I] + J); + } + } + } + + template + inline void compute(acc_t * ACC, int c_idx, int s_idx, ArrayType & comparray, vector float * vs, vector float * fin_res) { + vector signed int vec_C[4]; + vector float CA[4] = {0}; + vector float res[4] = {0}; + __builtin_mma_disassemble_acc(vec_C, ACC); + for (int i = 0; i < 4; i++) { + CA[i] = vec_splats((float)(((double)comparray[c_idx + i]) * -128.0)); + res[i] = vec_add(vec_ctf(vec_C[i], 0), CA[i]); + fin_res[s_idx + i] = vec_madd(res[i], vs[s_idx + i], fin_res[s_idx + i]); + } + } + + inline void process_q4_elements(vector signed char (&c)[2], int * ca) { + const vector signed char lowMask = vec_splats((signed char)0xF); + const vector unsigned char v4 = vec_splats((unsigned char)0x4); + const vector signed char v8 = vec_splats((signed char)0x8); + vector signed int vsum = {0}; + vector signed int vsum2 = {0}; + c[0] = vec_and(c[1], lowMask); + c[1] = vec_sr(c[1], v4); + c[0] = vec_sub(c[0], v8); + c[1] = vec_sub(c[1], v8); + vsum = vec_sum4s(c[0], vsum); + vsum2 = vec_sum4s(c[1], vsum2); + vsum = vec_add(vsum, vsum2); + *(ca) = vsum[0] + vsum[1] + vsum[2] + vsum[3]; + } + + template + inline void vector_permute_store(V2 & s1, V2 & s2, V2 & s3, V2 & s4, V1 * vecOffset, bool flip) { + vector unsigned char swiz1 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vector unsigned char swiz2 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + vector unsigned char swiz3 = {0, 1, 2, 3, 8, 9, 10, 11, 16, 17, 18, 19, 24, 25, 26, 27}; + vector unsigned char swiz4 = {4, 5, 6, 7, 12, 13, 14, 15, 20, 21, 22, 23, 28, 29, 30, 31}; + V2 t1, t2, t3, t4, t5, t6, t7, t8; + vector unsigned char xor_vector; + uint8_t flip_vec = 0x80; + xor_vector = vec_splats(flip_vec); + t1 = vec_perm(s1, s2, swiz1); + t2 = vec_perm(s1, s2, swiz2); + t3 = vec_perm(s3, s4, swiz1); + t4 = vec_perm(s3, s4, swiz2); + t5 = vec_perm(t1, t3, swiz3); + t6 = vec_perm(t1, t3, swiz4); + t7 = vec_perm(t2, t4, swiz3); + t8 = vec_perm(t2, t4, swiz4); + if (flip == true) { + t5 = vec_xor(t5, xor_vector); + t6 = vec_xor(t6, xor_vector); + t7 = vec_xor(t7, xor_vector); + t8 = vec_xor(t8, xor_vector); + } + vec_xst(t5, 0, vecOffset); + vec_xst(t6, 0, vecOffset + 16); + vec_xst(t7, 0, vecOffset + 32); + vec_xst(t8, 0, vecOffset + 48); + } + + inline void unpack_q4_to_q8(vector signed char packed, vector signed char & lo, vector signed char & hi) { + const vector signed char lowMask = vec_splats((signed char)0x0F); + const vector signed char v8 = vec_splats((signed char)0x08); + const vector unsigned char v4 = vec_splats((unsigned char)4); + lo = vec_and(packed, lowMask); + hi = vec_sr(packed, v4); + lo = vec_sub(lo, v8); + hi = vec_sub(hi, v8); + } + + inline void vector_permute_store_fp16(vec_t * c, unsigned char * vecOffset) { + vec_t t[8], s[8]; + vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + for (int i = 0; i < 4; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + for (int i = 4; i < 8; i += 2) { + t[i + 0] = vec_perm(c[i + 0], c[i + 1], swiz1); + t[i + 1] = vec_perm(c[i + 0], c[i + 1], swiz2); + } + s[0] = vec_perm(t[0], t[2], swiz3); + s[1] = vec_perm(t[0], t[2], swiz4); + s[2] = vec_perm(t[1], t[3], swiz3); + s[3] = vec_perm(t[1], t[3], swiz4); + s[4] = vec_perm(t[4], t[6], swiz3); + s[5] = vec_perm(t[4], t[6], swiz4); + s[6] = vec_perm(t[5], t[7], swiz3); + s[7] = vec_perm(t[5], t[7], swiz4); + for (int i = 0; i < 8; ++i) { + vec_xst(s[i], 0, (vec_t *)(vecOffset + i * 16)); + } + } + + static inline void convert_and_scale_q8(vector signed char raw, vector float v_scale, vector unsigned short & out_hi, vector unsigned short & out_lo) { + vector signed short i16_hi = vec_unpackh(raw); + vector signed short i16_lo = vec_unpackl(raw); + + vector float f_hi_h = vec_ctf(vec_unpackh(i16_hi), 0); + vector float f_hi_l = vec_ctf(vec_unpackl(i16_hi), 0); + vector float f_lo_h = vec_ctf(vec_unpackh(i16_lo), 0); + vector float f_lo_l = vec_ctf(vec_unpackl(i16_lo), 0); + out_hi = vec_pack_to_short_fp32(vec_mul(f_hi_h, v_scale), vec_mul(f_hi_l, v_scale)); + out_lo = vec_pack_to_short_fp32(vec_mul(f_lo_h, v_scale), vec_mul(f_lo_l, v_scale)); + } + + void packNormal_q4_fp16(const block_q4_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + for (int i = 0; i < rows; i += 8) { + const block_q4_0 * rows_base[8]; + for (int r = 0; r < 8; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[8][4]; + for (int r = 0; r < 8; r++) { + const block_q4_0 * current_blk = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(current_blk->d)); + vector signed char v_qs = reinterpret_cast(vec_xl(0, current_blk->qs)); + vector signed char c1, c2; + unpack_q4_to_q8(v_qs, c1, c2); + convert_and_scale_q8(c1, v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c2, v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int c = 0; c < 4; c++) { + vector unsigned char c_arr[8]; + for (int r = 0; r < 8; r++) { + c_arr[r] = (vector unsigned char)hp_res[r][c]; + } + vector_permute_store_fp16((vec_t *)c_arr, vecOffset); + vecOffset += 128; + } + } + } + } + + template + static inline void pack_q8_block(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + unsigned char * vecOffset = vec; + const vec_t swiz1 = {0, 1, 2, 3, 16, 17, 18, 19, 4, 5, 6, 7, 20, 21, 22, 23}; + const vec_t swiz2 = {8, 9, 10, 11, 24, 25, 26, 27, 12, 13, 14, 15, 28, 29, 30, 31}; + const vec_t swiz3 = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19, 20, 21, 22, 23}; + const vec_t swiz4 = {8, 9, 10, 11, 12, 13, 14, 15, 24, 25, 26, 27, 28, 29, 30, 31}; + + for (int i = 0; i < rows; i += chunk_size) { + const block_q8_0 * rows_base[chunk_size]; + for (int r = 0; r < chunk_size; r++) { + rows_base[r] = a + (i + r) * lda; + } + for (int blk = 0; blk < blocks; blk++) { + vector unsigned short hp_res[chunk_size][4]; + for (int r = 0; r < chunk_size; r++) { + const block_q8_0 * b = rows_base[r] + blk; + vector float v_scale = vec_extract_fp32_from_shorth(vec_splats(b->d)); + vector signed char c[2]; + __vector_pair pair = __builtin_vsx_lxvp(0, (__vector_pair *)b->qs); + __builtin_vsx_disassemble_pair(c, & pair); + convert_and_scale_q8(c[0], v_scale, hp_res[r][0], hp_res[r][1]); + convert_and_scale_q8(c[1], v_scale, hp_res[r][2], hp_res[r][3]); + } + for (int col = 0; col < 4; col++) { + if constexpr (chunk_size == 8) { + vec_t t[8]; + t[0] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + t[1] = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + t[2] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + t[3] = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + t[4] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz1); + t[5] = vec_perm((vec_t)hp_res[4][col], (vec_t)hp_res[5][col], swiz2); + t[6] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz1); + t[7] = vec_perm((vec_t)hp_res[6][col], (vec_t)hp_res[7][col], swiz2); + + vec_xst(vec_perm(t[0], t[2], swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t[0], t[2], swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t[1], t[3], swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t[1], t[3], swiz4), 0, (vec_t *)(vecOffset + 48)); + vec_xst(vec_perm(t[4], t[6], swiz3), 0, (vec_t *)(vecOffset + 64)); + vec_xst(vec_perm(t[4], t[6], swiz4), 0, (vec_t *)(vecOffset + 80)); + vec_xst(vec_perm(t[5], t[7], swiz3), 0, (vec_t *)(vecOffset + 96)); + vec_xst(vec_perm(t[5], t[7], swiz4), 0, (vec_t *)(vecOffset + 112)); + vecOffset += 128; + } else { + vec_t t0 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz1); + vec_t t1 = vec_perm((vec_t)hp_res[0][col], (vec_t)hp_res[1][col], swiz2); + vec_t t2 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz1); + vec_t t3 = vec_perm((vec_t)hp_res[2][col], (vec_t)hp_res[3][col], swiz2); + + vec_xst(vec_perm(t0, t2, swiz3), 0, (vec_t *)(vecOffset + 0)); + vec_xst(vec_perm(t0, t2, swiz4), 0, (vec_t *)(vecOffset + 16)); + vec_xst(vec_perm(t1, t3, swiz3), 0, (vec_t *)(vecOffset + 32)); + vec_xst(vec_perm(t1, t3, swiz4), 0, (vec_t *)(vecOffset + 48)); + vecOffset += 64; + } + } + } + } + } + + void packNormal_q8_fp16(const block_q8_0 * a, int64_t lda, int rows, int blocks, unsigned char * vec) { + if (rows == 4) { + pack_q8_block<4>(a, lda, rows, blocks, vec); + } else { + pack_q8_block<8>(a, lda, rows, blocks, vec); + } + } + + template + void packNormalInt4(const TA * a, int64_t lda, int rows, int cols, int8_t * vec, std::array & comparray) { int64_t i, j; - TA *aoffset = NULL; - int8_t *vecOffset = NULL; - TA *aoffset1 = NULL, *aoffset2 = NULL, *aoffset3 = NULL, *aoffset4 = NULL; - TA *aoffset5 = NULL, *aoffset6 = NULL, *aoffset7 = NULL, *aoffset8 = NULL; + TA * aoffset = NULL; + int8_t * vecOffset = NULL; + TA * aoffset1 = NULL, * aoffset2 = NULL, * aoffset3 = NULL, * aoffset4 = NULL; + TA * aoffset5 = NULL, * aoffset6 = NULL, * aoffset7 = NULL, * aoffset8 = NULL; vector signed char c1[2] = {0}, c2[2] = {0}, c3[2] = {0}, c4[2] = {0}; vector signed char c5[2] = {0}, c6[2] = {0}, c7[2] = {0}, c8[2] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { @@ -2363,18 +2620,18 @@ class tinyBLAS_HP16_PPC { c7[1] = reinterpret_cast(vec_xl(0, aoffset7->qs)); c8[1] = reinterpret_cast(vec_xl(0, aoffset8->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); - process_q4_elements(c5, &comparray[4]); - process_q4_elements(c6, &comparray[5]); - process_q4_elements(c7, &comparray[6]); - process_q4_elements(c8, &comparray[7]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); + process_q4_elements(c5, & comparray[4]); + process_q4_elements(c6, & comparray[5]); + process_q4_elements(c7, & comparray[6]); + process_q4_elements(c8, & comparray[7]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); - vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset+128, false); - vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset+192, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); + vector_permute_store(c5[0], c6[0], c7[0], c8[0], vecOffset + 128, false); + vector_permute_store(c5[1], c6[1], c7[1], c8[1], vecOffset + 192, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2405,12 +2662,12 @@ class tinyBLAS_HP16_PPC { c3[1] = reinterpret_cast(vec_xl(0, aoffset3->qs)); c4[1] = reinterpret_cast(vec_xl(0, aoffset4->qs)); - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2434,12 +2691,12 @@ class tinyBLAS_HP16_PPC { case 1: c1[1] = reinterpret_cast(vec_xl(0, aoffset1->qs)); break; } - process_q4_elements(c1, &comparray[0]); - process_q4_elements(c2, &comparray[1]); - process_q4_elements(c3, &comparray[2]); - process_q4_elements(c4, &comparray[3]); + process_q4_elements(c1, & comparray[0]); + process_q4_elements(c2, & comparray[1]); + process_q4_elements(c3, & comparray[2]); + process_q4_elements(c4, & comparray[3]); vector_permute_store(c1[0], c2[0], c3[0], c4[0], vecOffset, false); - vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset+64, false); + vector_permute_store(c1[1], c2[1], c3[1], c4[1], vecOffset + 64, false); aoffset1 += lda; aoffset2 += lda; aoffset3 += lda; @@ -2450,39 +2707,38 @@ class tinyBLAS_HP16_PPC { } } - template template - void tinyBLAS_Q0_PPC::packNormal(const block_q8_0* a, int64_t lda, int rows, int cols, VA* vec, bool flip) { + void packNormal(const block_q8_0 * a, int64_t lda, int rows, int cols, VA * vec, bool flip) { int64_t i, j; - block_q8_0 *aoffset = NULL; - VA *vecOffset = NULL; - block_q8_0* aoffsets[8]; + block_q8_0 * aoffset = NULL; + VA * vecOffset = NULL; + block_q8_0 * aoffsets[8]; __vector_pair arr[8]; VB c[8][2] = {0}; VB c1[8] = {0}; VB c2[8] = {0}; - aoffset = const_cast(a); + aoffset = const_cast(a); vecOffset = vec; j = (rows >> 3); if (j > 0) { do { aoffsets[0] = aoffset; for (int it = 1; it < 8; it++) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; aoffset += 8 * lda; i = (cols >> 3); if (i > 0) { do { for (int it = 0; it < 8; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); - vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset+128, flip); - vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset+192, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); + vector_permute_store(c1[4], c1[5], c1[6], c1[7], vecOffset + 128, flip); + vector_permute_store(c2[4], c2[5], c2[6], c2[7], vecOffset + 192, flip); for (int it = 0; it < 8; it++) aoffsets[it] += lda; vecOffset += 256; @@ -2501,13 +2757,13 @@ class tinyBLAS_HP16_PPC { if (i > 0) { do { for (int it = 0; it < 4; it++) { - arr[it] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[it]->qs); - __builtin_vsx_disassemble_pair(c[it], &arr[it]); + arr[it] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[it]->qs); + __builtin_vsx_disassemble_pair(c[it], & arr[it]); c1[it] = c[it][0]; c2[it] = c[it][1]; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 4; it++) { aoffsets[it] += lda; } @@ -2520,24 +2776,24 @@ class tinyBLAS_HP16_PPC { if (rows & 3) { aoffsets[0] = aoffset; for (int it = 1; it < 3; it++ ) - aoffsets[it] = aoffsets[it-1] + lda; + aoffsets[it] = aoffsets[it - 1] + lda; i = (cols >> 3); if (i > 0) { do { switch(rows) { - case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[2]->qs); - __builtin_vsx_disassemble_pair(c[2], &arr[2]); + case 3: arr[2] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[2]->qs); + __builtin_vsx_disassemble_pair(c[2], & arr[2]); c1[2] = c[2][0]; c2[2] = c[2][1]; - case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[1]->qs); - __builtin_vsx_disassemble_pair(c[1], &arr[1]); + case 2: arr[1] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[1]->qs); + __builtin_vsx_disassemble_pair(c[1], & arr[1]); c1[1] = c[1][0]; c2[1] = c[1][1]; - case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair*)aoffsets[0]->qs); - __builtin_vsx_disassemble_pair(c[0], &arr[0]); + case 1: arr[0] = __builtin_vsx_lxvp(0, (__vector_pair *)aoffsets[0]->qs); + __builtin_vsx_disassemble_pair(c[0], & arr[0]); c1[0] = c[0][0]; c2[0] = c[0][1]; break; } vector_permute_store(c1[0], c1[1], c1[2], c1[3], vecOffset, flip); - vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset+64, flip); + vector_permute_store(c2[0], c2[1], c2[2], c2[3], vecOffset + 64, flip); for (int it = 0; it < 3; it++) aoffsets[it] += lda; vecOffset += 128; @@ -2547,8 +2803,7 @@ class tinyBLAS_HP16_PPC { } } - template - void tinyBLAS_Q0_PPC::mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { + void mnpack(int64_t m0, int64_t m, int64_t n0, int64_t n) { int m_rem = MIN(m - m0, 16); int n_rem = MIN(n - n0, 16); @@ -2585,8 +2840,7 @@ class tinyBLAS_HP16_PPC { } - template - void tinyBLAS_Q0_PPC::KERNEL_4x8(int64_t ii, int64_t jj) { + void KERNEL_4x8(int64_t ii, int64_t jj) { vec_t vec_A[8], vec_B[16] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2594,26 +2848,26 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<4>((A+(ii*lda)+l), lda, 4, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, 4, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 4, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 4, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x], vec_B[x+8]); } for (int I = 0; I<4; I++) { for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+4]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 4] + J) = (unhalf((A +((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 4; i++) { comparray[i] = 0; int ca = 0; @@ -2624,15 +2878,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 0, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 0, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii, jj+4, 4, fin_res); + save_res(ii, jj + 4, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x4(int64_t ii, int64_t jj) { + void KERNEL_8x4(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[8] = {0}; acc_t acc_0, acc_1; std::array comparray {}; @@ -2640,25 +2893,25 @@ class tinyBLAS_HP16_PPC { vector float vs[8] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 4, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 4, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < 8; I++) { + for (int J = 0; J < 4; J++) { + *((float *)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2669,15 +2922,14 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); + save_res(ii + 4, jj, 4, fin_res); } - template - void tinyBLAS_Q0_PPC::KERNEL_8x8(int64_t ii, int64_t jj) { + void KERNEL_8x8(int64_t ii, int64_t jj) { vec_t vec_A[16], vec_B[16] = {0}; acc_t acc_0, acc_1, acc_2, acc_3; acc_t acc_4, acc_5, acc_6, acc_7; @@ -2686,30 +2938,30 @@ class tinyBLAS_HP16_PPC { vector float vs[16] = {0}; bool isAblock_q4 = std::is_same_v; for (int l = 0; l < k; l++) { - __builtin_mma_xxsetaccz(&acc_0); - __builtin_mma_xxsetaccz(&acc_1); - __builtin_mma_xxsetaccz(&acc_2); - __builtin_mma_xxsetaccz(&acc_3); + __builtin_mma_xxsetaccz(& acc_0); + __builtin_mma_xxsetaccz(& acc_1); + __builtin_mma_xxsetaccz(& acc_2); + __builtin_mma_xxsetaccz(& acc_3); if (std::is_same_v) { - packNormalInt4<8>((A+(ii*lda)+l), lda, 8, 4, (int8_t*)vec_A, comparray); + packNormalInt4<8>((A + (ii * lda) + l), lda, 8, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, 8, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, 8, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, 8, 8, (uint8_t*)vec_B, true); + packNormal((B + (jj * ldb) + l), ldb, 8, 8, (uint8_t *)vec_B, true); for(int x = 0; x < 8; x++) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_1, vec_A[x+8], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_2, vec_A[x], vec_B[x+8]); - __builtin_mma_xvi8ger4pp(&acc_3, vec_A[x+8], vec_B[x+8]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_1, vec_A[x + 8], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_2, vec_A[x], vec_B[x + 8]); + __builtin_mma_xvi8ger4pp(& acc_3, vec_A[x + 8], vec_B[x + 8]); } - for (int I = 0; I<8; I++) { - for (int J = 0; J<4; J++) { - *((float*)&vs[I]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J)*ldb)+l)->d)); - *((float*)&vs[I+8]+J) = (unhalf((A+((ii+I)*lda)+l)->d) * unhalf((B+((jj+J+4)*ldb)+l)->d)); + for (int I = 0; I < 8 ; I++) { + for (int J = 0; J < 4; J++) { + *((float *)& vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); + *((float *)& vs[I + 8] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J + 4) * ldb) + l)->d)); } } if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < 8; i++) { comparray[i] = 0; int ca = 0; @@ -2720,19 +2972,99 @@ class tinyBLAS_HP16_PPC { aoffset += lda; } } - compute(&acc_0, 0, 0, comparray, vs, fin_res); - compute(&acc_1, 4, 4, comparray, vs, fin_res); - compute(&acc_2, 0, 8, comparray, vs, fin_res); - compute(&acc_3, 4, 12, comparray, vs, fin_res); + compute(& acc_0, 0, 0, comparray, vs, fin_res); + compute(& acc_1, 4, 4, comparray, vs, fin_res); + compute(& acc_2, 0, 8, comparray, vs, fin_res); + compute(& acc_3, 4, 12, comparray, vs, fin_res); } save_res(ii, jj, 0, fin_res); - save_res(ii+4, jj, 4, fin_res); - save_res(ii, jj+4, 8, fin_res); - save_res(ii+4, jj+4, 12, fin_res); + save_res(ii + 4, jj, 4, fin_res); + save_res(ii, jj + 4, 8, fin_res); + save_res(ii + 4, jj + 4, 12, fin_res); } - template - void tinyBLAS_Q0_PPC::gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { + void KERNEL_Q0(int64_t ii, int64_t jj, int64_t mc, int64_t nc, int64_t kc, int64_t l, vec_t * vec_A, vec_t * vec_B) { + acc_t acc[8]; + for (int i = 0; i < mc ; i += 16) { + for (int j = 0; j < nc; j += 8) { + int A0_base = (i / 16) * (2 * 32 * kc); + int B0_base = (j / 8) * (32 * kc); + for (int x = 0; x < 8; x++) { + __builtin_mma_xxsetaccz(&acc[x]); + } + for (int64_t kk = 0; kk < kc; kk++) { + int A0_block_idx = A0_base + kk * 32; + int B0_block_idx = B0_base + kk * 32; + int A1_block_idx = A0_block_idx + 32 * kc; + int B1_block_idx = B0_block_idx + 32 * kc; + vec_t * A0_block = & vec_A[A0_block_idx]; + vec_t * B0_block = & vec_B[B0_block_idx]; + vec_t * A1_block = & vec_A[A1_block_idx]; + for (int it = 0; it < 4; it++) { + for (int x = 0; x < 4; x++) { + __builtin_mma_xvf16ger2pp(& acc[0], A0_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[1], A0_block[8 * it + x], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[2], A0_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[3], A0_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + __builtin_mma_xvf16ger2pp(& acc[4], A1_block[8 * it + x], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[5], A1_block[8 * it + x], B0_block[8 * it+ x + 4]); + __builtin_mma_xvf16ger2pp(& acc[6], A1_block[8 * it + x + 4], B0_block[8 * it + x]); + __builtin_mma_xvf16ger2pp(& acc[7], A1_block[8 * it + x + 4], B0_block[8 * it + x + 4]); + } + } + } + if (l == 0) { + save_acc(& acc[0], ii + i, jj + j); + save_acc(& acc[1], ii + i, jj + j + 4); + save_acc(& acc[2], ii + i + 4, jj + j); + save_acc(& acc[3], ii + i + 4, jj + j + 4); + save_acc(& acc[4], ii + i + 8, jj + j); + save_acc(& acc[5], ii + i + 8, jj + j + 4); + save_acc(& acc[6], ii + i + 12, jj + j); + save_acc(& acc[7], ii + i + 12, jj + j + 4); + } else { + add_save_acc(& acc[0], ii + i, jj + j); + add_save_acc(& acc[1], ii + i, jj + j + 4); + add_save_acc(& acc[2], ii + i + 4, jj + j); + add_save_acc(& acc[3], ii + i + 4, jj + j + 4); + add_save_acc(& acc[4], ii + i + 8, jj + j); + add_save_acc(& acc[5], ii + i + 8, jj + j + 4); + add_save_acc(& acc[6], ii + i + 12, jj + j); + add_save_acc(& acc[7], ii + i + 12, jj + j + 4); + } + } + } + } + + void matmul_tiled(int64_t m, int64_t n, int64_t mc, int64_t nc, int64_t kc) { + vec_t A_pack[mc * kc * 4]; + vec_t B_pack[nc * kc * 4]; + constexpr bool is_Ablock_q4 = std::is_same_v; + int64_t ytiles = m / mc; + int64_t xtiles = n / nc; + int64_t tiles = xtiles * ytiles; + int64_t duty = (tiles + nth - 1) / nth; + int64_t start = duty * ith; + int64_t end = start + duty; + if (end > tiles) { + end = tiles; + } + for (int64_t job = start; job < end; ++job) { + int64_t ii = (job / xtiles) * mc; + int64_t jj = (job % xtiles) * nc; + for (int64_t kk = 0; kk < k; kk += kc) { + if constexpr(is_Ablock_q4) { + packNormal_q4_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } else { + packNormal_q8_fp16(A + ii * lda + kk, lda, mc, kc, (uint8_t *)A_pack); + } + packNormal_q8_fp16(B + jj * ldb + kk, ldb, nc, kc, (uint8_t *)B_pack); + KERNEL_Q0(ii, jj, mc, nc, kc, kk, A_pack, B_pack); + } + } + } + + void gemm_small(int64_t m0, int64_t m, int64_t n0, int64_t n, int RM, int RN) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2754,32 +3086,32 @@ class tinyBLAS_HP16_PPC { vector float fin_res[4] = {0}; vector float vs[4] = {0}; vector float CA[4] = {0}; - __builtin_prefetch((A+(ii*lda)+0)->qs, 0, 1); // prefetch first value - __builtin_prefetch((B+(jj*ldb)+0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((A + (ii * lda) + 0)->qs, 0, 1); // prefetch first value + __builtin_prefetch((B + (jj * ldb) + 0)->qs, 0, 1); // prefetch first value for (int l = 0; l < k; l++) { - __builtin_prefetch((A+(ii*lda)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_prefetch((B+(jj*ldb)+(l+1))->qs, 0, 1); // prefetch one loop ahead - __builtin_mma_xxsetaccz(&acc_0); + __builtin_prefetch((A + (ii * lda) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_prefetch((B + (jj * ldb) + (l + 1))->qs, 0, 1); // prefetch one loop ahead + __builtin_mma_xxsetaccz(& acc_0); if (isAblock_q4) { - packNormalInt4<4>((A+(ii*lda)+l), lda, RM, 4, (int8_t*)vec_A, comparray); + packNormalInt4<4>((A + (ii * lda) + l), lda, RM, 4, (int8_t *)vec_A, comparray); } else { - packNormal((const block_q8_0*)(A+(ii*lda)+l), lda, RM, 8, (int8_t*)vec_A, false); + packNormal((const block_q8_0 *)(A + (ii * lda) + l), lda, RM, 8, (int8_t *)vec_A, false); } - packNormal((B+(jj*ldb)+l), ldb, RN, 8, (uint8_t*)vec_B, true); - for(int x = 0; x < 8; x+=4) { - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x], vec_B[x]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+1], vec_B[x+1]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+2], vec_B[x+2]); - __builtin_mma_xvi8ger4pp(&acc_0, vec_A[x+3], vec_B[x+3]); + packNormal((B + (jj * ldb) + l), ldb, RN, 8, (uint8_t *)vec_B, true); + for (int x = 0; x < 8; x += 4) { + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x], vec_B[x]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 1], vec_B[x + 1]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 2], vec_B[x + 2]); + __builtin_mma_xvi8ger4pp(& acc_0, vec_A[x + 3], vec_B[x + 3]); } - for (int I = 0; Id) * unhalf((B+((jj+J)*ldb)+l)->d)); + for (int I = 0; I < RM; I++) { + for (int J = 0; J < RN; J++) { + *((float*)&vs[I] + J) = (unhalf((A + ((ii + I) * lda) + l)->d) * unhalf((B + ((jj + J) * ldb) + l)->d)); } } - __builtin_mma_disassemble_acc(vec_C, &acc_0); + __builtin_mma_disassemble_acc(vec_C, & acc_0); if (!isAblock_q4) { - auto aoffset = A+(ii*lda)+l; + auto aoffset = A + (ii * lda) + l; for (int i = 0; i < RM; i++) { comparray[i] = 0; int ca = 0; @@ -2800,9 +3132,21 @@ class tinyBLAS_HP16_PPC { } } - template + template + inline void kernel(int64_t ii, int64_t jj) { + if constexpr(RM == 4 && RN == 8) { + KERNEL_4x8(ii,jj); + } else if constexpr(RM == 8 && RN == 4) { + KERNEL_8x4(ii,jj); + } else if constexpr(RM == 8 && RN == 8) { + KERNEL_8x8(ii,jj); + } else { + assert(false && "RN/RM values not supported"); + } + } + template - NOINLINE void tinyBLAS_Q0_PPC::gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { + NOINLINE void gemm(int64_t m0, int64_t m, int64_t n0, int64_t n) { int64_t ytiles = (m - m0) / RM; int64_t xtiles = (n - n0) / RN; int64_t tiles = xtiles * ytiles; @@ -2814,12 +3158,20 @@ class tinyBLAS_HP16_PPC { for (int64_t job = start; job < end; ++job) { int64_t ii = m0 + job / xtiles * RM; int64_t jj = n0 + job % xtiles * RN; - this->kernel(ii, jj); + kernel(ii, jj); } } - -template class tinyBLAS_Q0_PPC; -template class tinyBLAS_Q0_PPC; + const TA * const A; + const block_q8_0 * const B; + float * C; + const int64_t k; + int64_t kc; + const int64_t lda; + const int64_t ldb; + const int64_t ldc; + const int ith; + const int nth; +}; class tinyBLAS_PPC { public: diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ce15b18ce0..b7a70e06f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3,6 +3,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "binary-ops.h" +#include "simd-gemm.h" #include "ggml.h" #include "unary-ops.h" #include "vec.h" @@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -7629,8 +7694,7 @@ static void ggml_compute_forward_pad_f32( const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(src0->nb[0] == sizeof(float)); - GGML_ASSERT( dst->nb[0] == sizeof(float)); + assert(dst->nb[0] == sizeof(float)); const int ith = params->ith; const int nth = params->nth; @@ -8326,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( GGML_ASSERT(k->type == v->type); const ggml_type kv_type = k->type; - const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); - const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; - const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; - const size_t kv_type_size = ggml_type_size(kv_type); // broadcast factors const int64_t rk2 = neq2/nek2; @@ -8361,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); - int ir = ir0; while (ir < ir1) { // q indices for the start of this tile @@ -8389,18 +8447,20 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } // Per-thread scratch layout: - // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar) // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) - // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) - float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + // V32: KV_TILE_SZ * DV (F32 buffer for V tile) + // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32); void * Q_q = base; float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; - float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); @@ -8413,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - for (int tq = 0; tq < tile_rows; tq++) { - const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); - kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); - } - // Zero-pad remaining rows - for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { - memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + { + float * Q_f32 = (float *)Q_q; + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); + } + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset(Q_f32 + tq * DK, 0, DK * sizeof(float)); + } } + memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); + memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); // skip the tile entirely if all the masks are -inf if (mask) { bool can_skip = true; for (int tq = 0; tq < tile_rows; tq++) { const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); - for (int tk = 0; tk < KV_TILE_SZ; tk++) { + for (int tk = 0; tk < kv_tile; tk++) { mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { can_skip = false; } } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } } if (can_skip) { @@ -8442,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); - float s; - kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); - KQ[tq * KV_TILE_SZ + tk] = s * scale; + // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim) + // Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3; + if (kv_type == GGML_TYPE_F16) { + const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]); + } + } else { + const float * k_f32_src = (const float *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; + } + } + } + memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale); + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } } } @@ -8488,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled( S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); } - // Convert V tile to F32 first (if F16), then do MAD - // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. - // TODO: on ARM, native f16 should be faster - if (kv_type == GGML_TYPE_F16) { - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); - } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); - } - } - } else { - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_vec_mad_f32(DV, vkq_row, v_row, p); - } + // V accumulation: VKQ32 += softmax(KQ) * V + // Pack V tile to contiguous F32, zero-padded + for (int tk = 0; tk < kv_tile; tk++) { + const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3; + if (kv_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); + } else { + memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) { + memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float)); + } + } + simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV); } // sinks (apply only to valid rows in the tile) @@ -8731,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; - static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - const bool use_tiled = !use_ref && + bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_or_f16 && k->type == v->type && - nek1 % KV_TILE_SZ == 0 && neq1 >= Q_TILE_SZ); - +#ifdef GGML_SIMD + use_tiled &= (DV % GGML_F32_EPR == 0); +#endif int current_chunk = ith; while (current_chunk < nchunk) { diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 24e8ab4618..f94426ddd7 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row); } +template +static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0f; + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t a_l = a_ptr[l].qs[base_l + i]; + const int8_t a_h = a_ptr[l].qs[base_h + i]; + + sumi_l += q_l * a_l; + sumi_h += q_h * a_h; + } + + sumf[j] += + (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + } + + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j]; + } + } +} + +template +static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + const int blocks_per_half = 64 / blocklen; + const int q8_half_stride = 512; + const int q8_low_high_step = 256; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + + float sumf[4][8]; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0f; + } + } + + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen; + const int base_h = base_l + 64; + + const int scale_idx_l = base_l / 16; + const int scale_idx_h = base_h / 16; + + const int qh_shift_l = ((base_l % 128) / 32) * 2; + const int qh_shift_h = ((base_h % 128) / 32) * 2; + + const int qh_half_l = (base_l / 128) * 32; + const int qh_half_h = (base_h / 128) * 32; + + const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4); + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j]; + const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j]; + + int sumi_l = 0; + int sumi_h = 0; + + for (int i = 0; i < blocklen; i++) { + const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i; + const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; + const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; + + const int qh_idx_l = qh_half_l + ((base_l + i) % 32); + const int qh_chunk_l = qh_idx_l / blocklen; + const int qh_pos_l = qh_idx_l % blocklen; + const int qh_offset_l = + qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l; + const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; + + const int qh_idx_h = qh_half_h + ((base_h + i) % 32); + const int qh_chunk_h = qh_idx_h / blocklen; + const int qh_pos_h = qh_idx_h % blocklen; + const int qh_offset_h = + qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h; + const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; + + const int q_l = ((hi_2_l << 4) | l_4) - 32; + const int q_h = ((hi_2_h << 4) | hi_4) - 32; + + const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i]; + const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step]; + + sumi_l += q_l * q8_l; + sumi_h += q_h * q8_h; + } + + sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * + a_ptr[l].d[m]; + } + } + } + } + + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n, } +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { - constexpr int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0f; - } - - for (int l = 0; l < nb; l++) { - - - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - // qh indexing with 8-byte interleaving (like q5_K) - const int qh_byte_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_byte_l / 8; - const int qh_pos_l = qh_byte_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_byte_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_byte_h / 8; - const int qh_pos_h = qh_byte_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t a_l = a_ptr[l].qs[base_l + i]; - const int8_t a_h = a_ptr[l].qs[base_h + i]; - - sumi_l += q_l * a_l; - sumi_h += q_h * a_h; - } - - sumf[j] += - (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - } - - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j]; - } - } + ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n, } } -void ggml_gemm_q6_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - - float sumf[4][8]; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb); - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0f; - } - } - - for (int l = 0; l < nb; l++) { - for (int k = 0; k < 16; k++) { - // k = 0.. 7 weights 0-63 low, 64-127 high - // k = 8..15 weights 128-191 low, 192-255 high - const int base_l = (k / 8) * 128 + (k % 8) * 8; - const int base_h = base_l + 64; - - const int scale_idx_l = base_l / 16; - const int scale_idx_h = base_h / 16; - - // Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half - const int qh_shift_l = ((base_l % 128) / 32) * 2; - const int qh_shift_h = ((base_h % 128) / 32) * 2; - - // qh_half: offset to the correct 32-byte half (0 or 32) - const int qh_half_l = (base_l / 128) * 32; - const int qh_half_h = (base_h / 128) * 32; - - // Activation base indices for q8_Kx4 interleaved format - // Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32 - const int q8_base = (k / 8) * 512 + (k % 8) * 32; - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - // Interleaved scales - const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j]; - const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j]; - - int sumi_l = 0; - int sumi_h = 0; - - for (int i = 0; i < blocklen; i++) { - const int ql_pos = k * 64 + j * 8 + i; - const int l_4 = b_ptr[l].ql[ql_pos] & 0xF; - const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF; - - const int qh_idx_l = qh_half_l + ((base_l + i) % 32); - const int qh_chunk_l = qh_idx_l / 8; - const int qh_pos_l = qh_idx_l % 8; - const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l; - const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3; - - const int qh_idx_h = qh_half_h + ((base_h + i) % 32); - const int qh_chunk_h = qh_idx_h / 8; - const int qh_pos_h = qh_idx_h % 8; - const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h; - const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3; - - const int q_l = ((hi_2_l << 4) | l_4) - 32; - const int q_h = ((hi_2_h << 4) | hi_4) - 32; - - const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i]; - const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256]; - - sumi_l += q_l * q8_l; - sumi_h += q_h * q8_h; - } - - sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * - a_ptr[l].d[m]; - } - } - } - } - - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; - } - } - } - } +void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1901,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; + // buffer large enough for the max interleave block size (8 bytes) uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave); + memcpy(&out.qs[dst_offset], &elems, blck_size_interleave); } // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K @@ -2097,18 +2113,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in } const int end_ls = QK_K * 4 / blck_size_interleave; - // Interleave Q6_K quants by taking 8 bytes at a time + // Interleave Q6_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end_ls; ++i) { int src_id = i % n_blocks; int src_offset = (i / n_blocks) * blck_size_interleave; int dst_offset = i * blck_size_interleave; uint64_t elem_ls; - memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t)); - memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t)); + memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave); + memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave); } - // Interleave high bits using same 8-byte pattern as low bits + // Interleave high bits using same chunk size as low bits const int end_hs = end_ls / 2; for (int i = 0; i < end_hs; ++i) { int src_id = i % n_blocks; @@ -2116,8 +2132,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in int dst_offset = i * blck_size_interleave; uint64_t elem_hs; - memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t)); - memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t)); + memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave); + memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave); } // The below logic is designed so as to unpack and rearrange scales in Q6_K @@ -2262,7 +2278,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q6_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data; @@ -2511,6 +2527,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size); } @@ -2575,6 +2595,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2634,6 +2658,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -3043,6 +3071,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K + static const ggml::cpu::repack::tensor_traits q6_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q6_K_8x8_q8_K; // instance for Q2 @@ -3107,6 +3136,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q6_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q6_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_IQ4_NL) { if (ggml_cpu_has_avx2()) { if (cur->ne[1] % 8 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 855320eeeb..39b6b48238 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h new file mode 100644 index 0000000000..78d663e593 --- /dev/null +++ b/ggml/src/ggml-cpu/simd-gemm.h @@ -0,0 +1,136 @@ +#pragma once + +// Computes C[M x N] += A[M x K] * B[K x N] + +#include "simd-mappings.h" + +// TODO: add support for sizeless vector types +#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic) + +// TODO: untested on avx512 +// These are in units of GGML_F32_EPR +#if defined(__AVX512F__) || defined (__ARM_NEON__) + static constexpr int GEMM_RM = 4; + static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32 +#elif defined(__AVX2__) || defined(__AVX__) + static constexpr int GEMM_RM = 6; + static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16 +#else + static constexpr int GEMM_RM = 2; + static constexpr int GEMM_RN = 2; +#endif + +template +static inline void simd_gemm_ukernel( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int K, int N) +{ + static constexpr int KN = GGML_F32_EPR; + + GGML_F32_VEC acc[RM][RN]; + for (int64_t i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN); + } + } + + for (int64_t kk = 0; kk < K; kk++) { + GGML_F32_VEC Bv[RN]; + for (int r = 0; r < RN; r++) { + Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN); + } + for (int64_t i = 0; i < RM; i++) { + GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]); + for (int r = 0; r < RN; r++) { + acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p); + } + } + } + + for (int64_t i = 0; i < RM; i++) { + for (int r = 0; r < RN; r++) { + GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]); + } + } +} + +// C[M x N] += A[M x K] * B[K x N] +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + static constexpr int KN = GGML_F32_EPR; + + int64_t ii = 0; + for (; ii + GEMM_RM <= M; ii += GEMM_RM) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel(C + jj, A, B + jj, K, N); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel(C + jj, A, B + jj, K, N); + } + for (; jj < N; jj++) { + for (int64_t i = 0; i < GEMM_RM; i++) { + float a = C[i * N + jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[i + kk] * B[kk * N + jj]; + } + C[i * N + jj] = a; + } + } + + A += GEMM_RM * K; + C += GEMM_RM * N; + } + + // Tail rows: one at a time + for (; ii < M; ii++) { + int64_t jj = 0; + for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) { + simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N); + } + for (; jj + KN <= N; jj += KN) { + simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N); + } + for (; jj < N; jj++) { + float a = C[jj]; + for (int64_t kk = 0; kk < K; kk++) { + a += A[kk] * B[kk * N + jj]; + } + C[jj] = a; + } + + A += K; + C += N; + } +} + +#if defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +#else // scalar path + +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + for (int64_t i = 0; i < M; i++) { + for (int64_t j = 0; j < N; j++) { + float sum = C[i * N + j]; + for (int64_t kk = 0; kk < K; kk++) { + sum += A[i * K + kk] * B[kk * N + j]; + } + C[i * N + j] = sum; + } + } +} + +#endif // GGML_SIMD diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 630e506542..22de55700d 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { float32x4_t tmp = x[0] + vec_reve(x[0]); \ res = tmp[0] + tmp[1]; \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + float32x4_t v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float)vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// BF16 s390x +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 __vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__riscv_v_intrinsic) // compatible with vlen >= 128 diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f..1d8344436f 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 8708cd4e92..d0e4001338 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); -#endif -#if defined(__POWER9_VECTOR__) +#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__) const int np = (n & ~(GGML_BF16_STEP - 1)); if (np > 0) { GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index d313c1ac9a..262f88204e 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND) FetchContent_Declare( CCCL GIT_REPOSITORY https://github.com/nvidia/cccl.git - GIT_TAG v3.2.0-rc2 + GIT_TAG v3.2.0 GIT_SHALLOW TRUE ) diff --git a/ggml/src/ggml-cuda/binbcast.cu b/ggml/src/ggml-cuda/binbcast.cu index 0e6d777b1e..7339fe0c07 100644 --- a/ggml/src/ggml-cuda/binbcast.cu +++ b/ggml/src/ggml-cuda/binbcast.cu @@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0, for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) { const uint32_t i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const uint3 ne11, const uint3 ne12, const uint3 ne13, - /*int s0, */ const int s1, + /*const int s0,*/ + const int s1, const int s2, const int s3, - /*int s00,*/ const int s01, + const int s00, + const int s01, const int s02, const int s03, - /*int s10,*/ const int s11, + const int s10, + const int s11, const int s12, const int s13, src1_ptrs... src1s) { @@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const int i10 = fastmodulo(i0, ne10); - float result = src0_row ? (float) src0_row[i0] : 0.0f; + float result = src0_row ? (float) src0_row[i0*s00] : 0.0f; if constexpr (sizeof...(src1_ptrs) > 0) { - result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10]))); + result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10]))); } else { - result = bin_op(result, (float)src1[i_src1 + i10]); + result = bin_op(result, (float)src1[i_src1 + i10*s10]); } dst_row[i0] = (dst_t) result; @@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * cnb[3] *= cne[3]; }; - if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + //size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s00 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0 / 2LL, 1LL); @@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor * k_bin_bcast_unravel<<>>( src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast_unravel <<>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } else { const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3); if constexpr (sizeof...(I) > 0) { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); + /*s0,*/ s1, s2, s3, + s00 ,s01, s02, s03, + s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...); } else { k_bin_bcast<<>>( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13, - /* s0, */ s1, s2, s3, - /* s00,*/ s01, s02, s03, - /* s10,*/ s11, s12, s13); + /*s0,*/ s1, s2, s3, + s00, s01, s02, s03, + s10, s11, s12, s13); } } } diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ba3d4eeb88..09b6d5db6a 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -7,7 +7,8 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, - const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); @@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; - const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } template @@ -485,9 +490,11 @@ template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template @@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t template static __global__ void convert_unary( - const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -621,23 +629,29 @@ static __global__ void convert_unary( } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; const src_t * x = (const src_t *) vx; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } template static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template diff --git a/ggml/src/ggml-cuda/fattn-tile.cuh b/ggml/src/ggml-cuda/fattn-tile.cuh index b6db582281..f3fa80ab23 100644 --- a/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1186,8 +1186,10 @@ static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggm GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); const int gqa_ratio = Q->ne[2] / K->ne[2]; + // On NVIDIA (Pascal and older) the GQA optimizations seem to be detrimental in some cases. + // However, for DKQ == 576, DV == 512 only the kernel variant with GQA optimizations is implemented. const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); - const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const int gqa_limit = nvidia && gqa_ratio <= 4 && DV <= 256 ? 16 : INT_MAX; const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; if constexpr (DV == 512) { diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8694fd06c7..f19defbff9 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; +#else typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; typedef wmma::fragment frag_c_KQ; typedef wmma::fragment frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + const _Float16 * K_h_f16 = reinterpret_cast(K_h); + const _Float16 * V_h_f16 = reinterpret_cast(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 9e77c231c8..ffa35eeb65 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2278,11 +2278,12 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE); if (ne2 <= MMVQ_MAX_BATCH_SIZE) { if (ggml_is_quantized(src0->type)) { - if (ne2 <= 4) { + if (ne2 <= MMVQ_MMID_MAX_BATCH_SIZE) { ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); return; } @@ -2305,6 +2306,8 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor * } } + // note: this path should not be reached when recording CUDA graphs, because it requires stream synchronization + // TODO: add asserts to verify this. should work with CUDA, HIP, etc. cudaStream_t stream = ctx.stream(); GGML_ASSERT(nb12 % nb11 == 0); @@ -2865,14 +2868,6 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph - const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; - const std::string gemma3n_per_layer_proj_src1_name = "per_layer_proj"; - const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; - const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; - const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; - const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; - const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; - for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2887,30 +2882,14 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { #endif } - if (node->op == GGML_OP_MUL_MAT_ID && node->ne[2] != 1) { - use_cuda_graph = false; // This node type is not supported by CUDA graph capture -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); -#endif - } - - if (node->op == GGML_OP_ADD && - node->src[1] && node->src[1]->ne[1] > 1 && - (node->src[0] ? node->src[0]->name != gemma3n_per_layer_proj_src0_name : true) && - (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && - strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && - strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { - // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation - // by means of matching node names. See - // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and - // https://github.com/huggingface/transformers/blob/bda75b4011239d065de84aa3e744b67ebfa7b245/src/transformers/models/gemma3n/modeling_gemma3n.py#L1773, - // Generally, changes in batch size or context size can cause changes to the grid size of some kernels. + // [TAG_MUL_MAT_ID_CUDA_GRAPHS] + if (node->op == GGML_OP_MUL_MAT_ID && (!ggml_is_quantized(node->src[0]->type) || node->ne[2] > MMVQ_MMID_MAX_BATCH_SIZE)) { + // under these conditions, the mul_mat_id operation will need to synchronize the stream, so we cannot use CUDA graphs + // TODO: figure out a way to enable for larger batch sizes, without hurting performance + // ref: https://github.com/ggml-org/llama.cpp/pull/18958 use_cuda_graph = false; #ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to batch size > 1 [%s] [%ld %ld %ld %ld]\n", __func__, node->name, node->ne[0], node->ne[1], node->ne[2], node->ne[3]); + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to unsupported node type\n", __func__); #endif } @@ -3640,11 +3619,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { + ggml_tensor fused_add_node; + memcpy(&fused_add_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; + ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; @@ -4542,6 +4523,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + // TODO: should become: + //return ggml_is_contiguous_rows(op->src[0]); return ggml_is_contiguous(op->src[0]); default: return false; @@ -4820,8 +4803,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + // TODO: extend support like so: + //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]); + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_TOP_K: @@ -4834,8 +4820,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: - case GGML_OP_PAD: return ggml_is_contiguous(op->src[0]); + case GGML_OP_PAD: + return true; case GGML_OP_UPSCALE: case GGML_OP_PAD_REFLECT_1D: case GGML_OP_ARANGE: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f80f98cda2..255e59f6fc 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2715,14 +2715,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * l)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; @@ -2733,12 +2733,12 @@ template static __device__ __forceinline__ void loa #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } - const int ls = aux32 >> 28; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) const float d = bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2776,11 +2776,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l] >> 9); - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; @@ -2904,11 +2907,13 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; diff --git a/ggml/src/ggml-cuda/mmvq.cuh b/ggml/src/ggml-cuda/mmvq.cuh index 4bb10cfaec..8a154631f6 100644 --- a/ggml/src/ggml-cuda/mmvq.cuh +++ b/ggml/src/ggml-cuda/mmvq.cuh @@ -1,6 +1,7 @@ #include "common.cuh" #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels. +#define MMVQ_MMID_MAX_BATCH_SIZE 4 // Max. batch size for which to use MMVQ kernels for MUL_MAT_ID void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, const ggml_cuda_mm_fusion_args_host * fusion = nullptr); diff --git a/ggml/src/ggml-cuda/pad.cu b/ggml/src/ggml-cuda/pad.cu index 660c192e48..31cd00f778 100644 --- a/ggml/src/ggml-cuda/pad.cu +++ b/ggml/src/ggml-cuda/pad.cu @@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) { return (coord + size) % size; } -static __global__ void pad_f32(const float * src, float * dst, +static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, @@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i01 = i1 - lp1; const int64_t i02 = i2 - lp2; const int64_t i03 = i3 - lp3; - const int64_t ne02 = ne2 - lp2 - rp2; - const int64_t ne01 = ne1 - lp1 - rp1; - const int64_t ne00 = ne0 - lp0 - rp0; - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } else { @@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst, const int64_t i02 = wrap_around(i2 - lp2, ne02); const int64_t i03 = wrap_around(i3 - lp3, ne03); - const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00; + const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00; dst[dst_idx] = src[src_idx]; } } -static void pad_f32_cuda(const float * src, float * dst, +static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst, const int lp0, const int rp0, const int lp1, const int rp1, const int lp2, const int rp2, const int lp3, const int rp3, const int ne0, const int ne1, const int ne2, const int ne3, const bool circular, cudaStream_t stream) { int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; dim3 gridDim(num_blocks, ne1, ne2 * ne3); - pad_f32<<>>(src, dst, + pad_f32<<>>(src, s00, s01, s02, s03, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3, circular); } @@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { float * dst_d = (float *) dst->data; cudaStream_t stream = ctx.stream(); + GGML_TENSOR_UNARY_OP_LOCALS; + GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(src0)); const int32_t lp0 = ((const int32_t *) (dst->op_params))[0]; const int32_t rp0 = ((const int32_t *) (dst->op_params))[1]; @@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { const int32_t rp3 = ((const int32_t *) (dst->op_params))[7]; const int32_t circular = ((const int32_t *) (dst->op_params))[8]; - pad_f32_cuda(src0_d, dst_d, + const size_t s00 = nb00 / ggml_type_size(src0->type); + const size_t s01 = nb01 / ggml_type_size(src0->type); + const size_t s02 = nb02 / ggml_type_size(src0->type); + const size_t s03 = nb03 / ggml_type_size(src0->type); + + pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (bool) circular, stream); diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index 88ed79111a..45a49a5dc2 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -43,10 +43,15 @@ static __device__ void rope_yarn( template static __global__ void rope_norm(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; - - int idst = row_dst * ne0 + i0; - const int ix = channel_x*s2 + row_x*s1 + i0; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; + int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0; + idst += row_indices[i2] * set_rows_stride; } const auto & store_coaelsced = [&](float x0, float x1) { @@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x, template static __global__ void rope_neox(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int32_t * pos, const float freq_scale, @@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x, const int set_rows_stride) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - int idst = row_dst * ne0 + i0 / 2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; // Fusion optimization: ROPE + VIEW + SET_ROWS. // The rope output is viewed as a 1D tensor and offset based on a row index in row_indices. if (set_rows_stride != 0) { - idst = row_x * ne0 + i0 / 2; - idst += row_indices[channel_x] * set_rows_stride; + idst = i1 * s1 + i0 / 2; + idst += row_indices[i2] * set_rows_stride; } if (i0 >= n_dims) { @@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x, return; } - const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f); const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x, dst[idst + n_dims / 2] = ggml_cuda_cast(x0 * sin_theta + x1 * cos_theta); } -template -static __global__ void rope_multi( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, - const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) { - const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); +template +static __global__ void rope_multi(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope) { + const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; if (i0 >= n_dims) { dst[idst + i0/2 + 0] = x[ix + i0/2 + 0]; @@ -200,27 +229,24 @@ static __global__ void rope_multi( float theta_base = 0.0; if (is_imrope) { - if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); } else { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } else { if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { - theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); + theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f); + } else if (sector >= sec_w + sections.v[2]) { + theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f); } } @@ -238,37 +264,53 @@ static __global__ void rope_multi( dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta; } -template -static __global__ void rope_vision( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, - const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims, - const float theta_scale, const float * freq_factors, const mrope_sections sections) { +template +static __global__ void rope_vision(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int32_t * pos, + const float freq_scale, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float theta_scale, + const float * freq_factors, + const mrope_sections sections) { const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y); - if (i0 >= ne0) { + if (i0 >= ne00) { return; } const int row_dst = blockDim.x*blockIdx.x + threadIdx.x; - const int row_x = row_dst % ne1; - const int channel_x = row_dst / ne1; + const uint32_t i3 = row_dst / (ne01 * ne02); + const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01; + const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01; - const int idst = row_dst*ne0 + i0/2; - const int ix = channel_x*s2 + row_x*s1 + i0/2; + int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3; + const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03; const int sect_dims = sections.v[0] + sections.v[1]; - const int sec_w = sections.v[1] + sections.v[0]; - const int sector = (i0 / 2) % sect_dims; + const int sec_w = sections.v[1] + sections.v[0]; + const int sector = (i0 / 2) % sect_dims; float theta_base = 0.0; if (sector < sections.v[0]) { const int p = sector; - theta_base = pos[channel_x]*powf(theta_scale, p); - } - else if (sector >= sections.v[0] && sector < sec_w) { + theta_base = pos[i2] * powf(theta_scale, p); + } else if (sector >= sections.v[0] && sector < sec_w) { const int p = sector - sections.v[0]; - theta_base = pos[channel_x + ne2]*powf(theta_scale, p); + theta_base = pos[i2 + ne02] * powf(theta_scale, p); } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; @@ -288,10 +330,15 @@ static __global__ void rope_vision( template static void rope_norm_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_norm<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } template static void rope_neox_cuda(const T * x, D * dst, - const int ne0, - const int ne1, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, const int s1, const int s2, + const int s3, const int n_dims, const int nr, const int32_t * pos, @@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x, const int64_t * row_indices, const int set_rows_stride, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } else { rope_neox<<>>( - x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, - freq_factors, row_indices, set_rows_stride); + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, + attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride); } } -template -static void rope_multi_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_multi_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + const bool is_imrope, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); - const float theta_scale = powf(freq_base, -2.0f/n_dims); + const float theta_scale = powf(freq_base, -2.0f / n_dims); if (freq_factors == nullptr) { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } else { rope_multi<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope); } } -template -static void rope_vision_cuda( - const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr, - const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor, - const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) { - GGML_ASSERT(ne0 % 2 == 0); +template +static void rope_vision_cuda(const T * x, + T * dst, + const int ne00, + const int ne01, + const int ne02, + const int s01, + const int s02, + const int s03, + const int s1, + const int s2, + const int s3, + const int n_dims, + const int nr, + const int32_t * pos, + const float freq_scale, + const float freq_base, + const float ext_factor, + const float attn_factor, + const rope_corr_dims corr_dims, + const float * freq_factors, + const mrope_sections sections, + cudaStream_t stream) { + GGML_ASSERT(ne00 % 2 == 0); const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1); - const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE); + const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE); const dim3 block_nums(nr, n_blocks_x, 1); // break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq) // where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE); @@ -398,11 +487,11 @@ static void rope_vision_cuda( if (freq_factors == nullptr) { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } else { rope_vision<<>>( - x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor, + x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale, freq_factors, sections); } } @@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, const size_t s01 = src0->nb[1] / ggml_type_size(src0->type); const size_t s02 = src0->nb[2] / ggml_type_size(src0->type); + const size_t s03 = src0->nb[3] / ggml_type_size(src0->type); + + const size_t s1 = dst->nb[1] / ggml_type_size(dst->type); + const size_t s2 = dst->nb[2] / ggml_type_size(dst->type); + const size_t s3 = dst->nb[3] / ggml_type_size(dst->type); //const int n_past = ((int32_t *) dst->op_params)[0]; const int n_dims = ((int32_t *) dst->op_params)[1]; @@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx, // compute if (is_neox) { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_neox_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } } else if (is_mrope && !is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_multi_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_multi_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream); + rope_multi_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, is_imrope, stream); } else { GGML_ABORT("fatal error"); } } else if (is_vision) { if (src0->type == GGML_TYPE_F32) { - rope_vision_cuda( - (const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else if (src0->type == GGML_TYPE_F16) { - rope_vision_cuda( - (const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale, - freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream); + rope_vision_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1, + s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor, + corr_dims, freq_factors, sections, stream); } else { GGML_ABORT("fatal error"); } } else { if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) { - rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, - nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) { - rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr, - pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims, - freq_factors, row_indices, set_rows_stride, stream); + rope_norm_cuda((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, + s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base, + ext_factor, attn_factor, corr_dims, freq_factors, row_indices, + set_rows_stride, stream); } else { GGML_ABORT("fatal error"); } diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6baab1176f..ab803aca21 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } +static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { + // v is a 7 bit int, with the 8th sign being encodable as popcnt + // with xor we can "correct" the bit instead of having to mask + const uint32_t p = __popc(v) & 1; + const uint32_t s = v ^ p << 7; + // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors + return s * 0x01010101; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( int sumi = 0; #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); sumi = ggml_cuda_dp4a(grid0, u0, sumi); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const int ls = aux32 >> 28; - sumi = (ls*sumi + sumi/2)/4; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) + sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4 const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; } @@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( int sumi1 = 0; #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); if (l0 < 4) { @@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); sumi = ggml_cuda_dp4a(grid_l, u0, sumi); diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 4f0a1620fb..54f9986498 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se return false; } - // TODO: add support for non-contigiuos tensors - if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses return true; } +static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * dst = op; + + if (!hex_supported_src0_type(src0->type)) { + return false; + } + if (!hex_supported_dst_type(dst->type)) { + return false; + } + + // TODO: add support for non-contigiuos tensors + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; @@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * return true; } +static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * dst = op; // indices + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (dst->type != GGML_TYPE_I32) { + return false; + } + + if (src0->ne[0] > (16*1024)) { + // reject tensors with huge rows for now + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu case GGML_OP_SUB: req->op = HTP_OP_SUB; break; + case GGML_OP_DIV: + req->op = HTP_OP_DIV; + break; default: GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op); break; @@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_ARGSORT; + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SQR: + req->op = HTP_OP_SQR; + supported = true; + break; + + case GGML_OP_SQRT: + req->op = HTP_OP_SQRT; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf } else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) { req->op = HTP_OP_GLU_SWIGLU_OAI; supported = true; + } else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) { + req->op = HTP_OP_GLU_GEGLU; + supported = true; } break; @@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf return n_bufs; } +static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_SUM_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); req->op = HTP_OP_ROPE; @@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_ADD_ID: @@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + case GGML_OP_SUM_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; case GGML_OP_UNARY: if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) || (ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) { @@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg break; case GGML_OP_GLU: if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) || - (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) { + (ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) || + (ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) { ggml_hexagon_dispatch_op(sess, node, flags); } break; @@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_ARGSORT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_MUL: case GGML_OP_ADD: case GGML_OP_SUB: + case GGML_OP_DIV: supp = ggml_hexagon_supported_binary(sess, op); break; @@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_unary(sess, op); break; + case GGML_OP_SQR: + case GGML_OP_SQRT: + supp = ggml_hexagon_supported_unary(sess, op); + break; + + case GGML_OP_SUM_ROWS: + supp = ggml_hexagon_supported_sum_rows(sess, op); + break; + case GGML_OP_SOFT_MAX: supp = ggml_hexagon_supported_softmax(sess, op); break; @@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons case GGML_OP_GLU: { const auto glu_op = ggml_get_glu_op(op); - if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) { + if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) { supp = ggml_hexagon_supported_activations(sess, op); } break; @@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_cpy(sess, op); break; + case GGML_OP_ARGSORT: + supp = ggml_hexagon_supported_argsort(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index e8ef203045..2c23b60da3 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include_directories( ${HEXAGON_SDK_ROOT}/incs ${HEXAGON_SDK_ROOT}/incs/stddef + ${CMAKE_CURRENT_SOURCE_DIR}/../../../include ${CMAKE_CURRENT_SOURCE_DIR}/../.. ${CMAKE_CURRENT_SOURCE_DIR}/.. ${CMAKE_CURRENT_SOURCE_DIR} @@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED matmul-ops.c binary-ops.c unary-ops.c + sum-rows-ops.c softmax-ops.c act-ops.c rope-ops.c @@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED set-rows-ops.c get-rows-ops.c cpy-ops.c + argsort-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index c3daf5adb2..950d836ad3 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, // gelu = x * sigmoid(1.702 * x) // current implementation hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, // silu = x * sigmoid(x) hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static const float GELU_COEF_A = 0.044715f; +static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; + } + + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float))); + const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float))); + uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float))); + + // geglu tanh implementation + // geglu(x, g) = gelu(x) * g + // gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))) + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI + hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res) + hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f + hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x + hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f + hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, @@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } +static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = (struct htp_ops_context *) data; + glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); +} + static int execute_op_activations_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { act_op_func = unary_gelu_f32; op_type = "gelu-f32"; break; + + case HTP_OP_GLU_GEGLU: + act_op_func = glu_geglu_f32; + op_type = "geglu-f32"; + break; default: FARF(ERROR, "Unsupported activations Op %u\n", octx->op); return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/argsort-ops.c b/ggml/src/ggml-hexagon/htp/argsort-ops.c new file mode 100644 index 0000000000..a4cee980be --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/argsort-ops.c @@ -0,0 +1,281 @@ +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "ggml.h" + +#include "hvx-utils.h" +#include "hex-dma.h" + +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif + +struct htp_argsort_context { + struct htp_ops_context * octx; + uint32_t nrows_per_thread; +}; + +static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y) +{ + const HVX_Vector one = Q6_V_vsplat_R(1); + const HVX_Vector zero = Q6_V_vzero(); + + HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y); + HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero); + HVX_Vector sum = hvx_vec_reduce_sum_i32(matches); + return hvx_vec_get_i32(sum) == 32; +} + +// Sorts values and mirrors swaps to indices. +static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + while (i <= j) { + // Vectorized scan for i + while (i <= j) { + // Check if we have at least one full vector + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(pivot_vec, vals_vec)) { + // If all elements are < pivot, we can skip this whole block + i += 32; + continue; + } + } + + // Scalar fallback / cleanup + if (values[i] < pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j + while (i <= j) { + if (j - 32 >= i) { + // Load 32 elements ending at j. + // Since we want `values[j] > pivot`, let's load from j-31 to j. + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(vals_vec, pivot_vec)) { + j -= 32; + continue; + } + } + + if (values[j] > pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_asc(values, indices, left, j); + if (i < right) quicksort_values_indices_asc(values, indices, i, right); +} + +static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) { + if (left >= right) return; + + int pivot_idx = (left + right) / 2; + float pivot = values[pivot_idx]; + int i = left; + int j = right; + + HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot); + + while (i <= j) { + // Vectorized scan for i (values[i] > pivot) + while (i <= j) { + if (i + 32 <= j) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + i); + if (all_greater_f32(vals_vec, pivot_vec)) { + i += 32; + continue; + } + } + + if (values[i] > pivot) { + i++; + } else { + break; + } + } + + // Vectorized scan for j (values[j] < pivot) + while (i <= j) { + if (j - 32 >= i) { + HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31); + if (all_greater_f32(pivot_vec, vals_vec)) { + j -= 32; + continue; + } + } + + if (values[j] < pivot) { + j--; + } else { + break; + } + } + + if (i <= j) { + float tmp_val = values[i]; + values[i] = values[j]; + values[j] = tmp_val; + + int32_t tmp_idx = indices[i]; + indices[i] = indices[j]; + indices[j] = tmp_idx; + i++; + j--; + } + } + + if (left < j) quicksort_values_indices_desc(values, indices, left, j); + if (i < right) quicksort_values_indices_desc(values, indices, i, right); +} + +static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) { + struct htp_argsort_context * actx = (struct htp_argsort_context *)data; + struct htp_ops_context * octx = actx->octx; + + // Unpack context + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + + // Scratchpad memory + uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i; + + // Dimensions + uint32_t ne00 = src0->ne[0]; + uint32_t ne01 = src0->ne[1]; + uint32_t ne02 = src0->ne[2]; + uint32_t ne03 = src0->ne[3]; + + uint32_t nb01 = src0->nb[1]; + //uint32_t nb02 = src0->nb[2]; + //uint32_t nb03 = src0->nb[3]; + + uint32_t nb1 = dst->nb[1]; + //uint32_t nb2 = dst->nb[2]; + //uint32_t nb3 = dst->nb[3]; + + // Sort order + enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0]; + + // Rows to process + uint32_t total_rows = ne01 * ne02 * ne03; + uint32_t rows_per_thread = actx->nrows_per_thread; + uint32_t start_row = rows_per_thread * i; + uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); + + // Scratchpad layout: + // We need space for one row of float data (values) and one row of int32 indices. + // values: ne00 * sizeof(float) + // indices: ne00 * sizeof(int32_t) + // Padded to 128 bytes. + + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + float * values_buf = (float *) spad; + int32_t * indices_buf = (int32_t *) (spad + values_size); + + for (uint32_t r = start_row; r < end_row; r++) { + uint32_t src_offset = r * nb01; + uint32_t dst_offset = r * nb1; + + uint8_t * src_ptr = (uint8_t *) src0->data + src_offset; + uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset; + + hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1); + hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00); + + // Initialize indices + for (uint32_t j = 0; j < ne00; j++) { + indices_buf[j] = j; + } + + // Sort values and mirror swaps to indices + if (order == GGML_SORT_ORDER_ASC) { + quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1); + } else { + quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1); + } + + // Copy indices back to DDR + hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00); + } +} + +int op_argsort(struct htp_ops_context * octx) { + // Check supported types + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + // Allocate scratchpad + // We need 1 row of float + 1 row of int32 per thread. + uint32_t ne00 = octx->src0.ne[0]; + size_t values_size = hex_round_up(ne00 * sizeof(float), 128); + size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128); + size_t spad_per_thread = values_size + indices_size; + + // Make sure we round up to 256 for alignment requirements + spad_per_thread = hex_round_up(spad_per_thread, 256); + + size_t total_spad_size = spad_per_thread * octx->n_threads; + + if (octx->ctx->vtcm_size < total_spad_size) { + FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src0_spad.size = total_spad_size; + octx->src0_spad.size_per_thread = spad_per_thread; + + FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)", + octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3], + octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3], + octx->src0.data, octx->dst.data); + + uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3]; + uint32_t n_jobs = MIN(total_rows, octx->n_threads); + + struct htp_argsort_context actx; + actx.octx = octx; + actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs; + + // Run jobs + worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs); + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/binary-ops.c b/ggml/src/ggml-hexagon/htp/binary-ops.c index de22afe460..00dbcf8798 100644 --- a/ggml/src/ggml-hexagon/htp/binary-ops.c +++ b/ggml/src/ggml-hexagon/htp/binary-ops.c @@ -17,15 +17,37 @@ #include "htp-msg.h" #include "htp-ops.h" -typedef void (*hvx_elemwise_f32_func)(uint8_t * data_dst, const uint8_t * src0, const uint8_t * src1, const uint32_t num_elems); +#ifndef MIN +#define MIN(a, b) ((a) < (b) ? (a) : (b)) +#endif -static hvx_elemwise_f32_func func_table_HVX[] = { hvx_mul_f32, hvx_add_f32, hvx_sub_f32 }; -static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f32_aa, hvx_sub_f32_aa }; +// Context for binary operations +struct htp_binary_context { + struct htp_ops_context * octx; + struct fastdiv_values dim1_div; + struct fastdiv_values dim2_div; + struct fastdiv_values dim12_div; + + struct fastdiv_values src1_dim1_div; // ne11 + struct fastdiv_values src1_dim2_div; // ne12 + struct fastdiv_values src1_dim3_div; // ne13 + + uint32_t nrows_per_thread; + bool split_at_ne01; + bool split_at_ne02; + + // Precomputed values + uint32_t block_max; + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + uint32_t src1_fetch_rows; // 1 or block_max + uint32_t src1_dma_stride; // 0 or stride +}; #define htp_binary_preamble \ const struct htp_tensor * src0 = &octx->src0; \ const struct htp_tensor * src1 = &octx->src1; \ - const struct htp_tensor * src2 = &octx->src2; \ struct htp_tensor * dst = &octx->dst; \ \ const uint32_t ne00 = src0->ne[0]; \ @@ -38,266 +60,696 @@ static hvx_elemwise_f32_func func_table_HVX_opt[] = { hvx_mul_f32_aa, hvx_add_f3 const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ - const uint32_t ne0 = dst->ne[0]; \ - const uint32_t ne1 = dst->ne[1]; \ - const uint32_t ne2 = dst->ne[2]; \ - const uint32_t ne3 = dst->ne[3]; \ - \ - const uint32_t nb00 = src0->nb[0]; \ const uint32_t nb01 = src0->nb[1]; \ const uint32_t nb02 = src0->nb[2]; \ const uint32_t nb03 = src0->nb[3]; \ \ - const uint32_t nb10 = src1->nb[0]; \ const uint32_t nb11 = src1->nb[1]; \ const uint32_t nb12 = src1->nb[2]; \ const uint32_t nb13 = src1->nb[3]; \ \ - const uint32_t nb0 = dst->nb[0]; \ const uint32_t nb1 = dst->nb[1]; \ const uint32_t nb2 = dst->nb[2]; \ - const uint32_t nb3 = dst->nb[3]; \ - \ - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t nb3 = dst->nb[3]; -static void binary_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - enum htp_op op) { - htp_binary_preamble; +static inline uint32_t calc_block_size(struct htp_binary_context * bctx, uint32_t ir, uint32_t end_row, + uint32_t ne01, uint32_t ne02) { + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + uint32_t rows_left = end_row - ir; + uint32_t block_limit = rows_left; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src1_nrows = ne11 * ne12 * ne13; // src1 rows - - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; + if (bctx->split_at_ne01) { + block_limit = MIN(block_limit, ne01 - i01); + } + if (bctx->split_at_ne02) { + uint32_t rows_in_plane = (ne02 * ne01) - rem; + block_limit = MIN(block_limit, rows_in_plane); } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } - - hvx_elemwise_f32_func func_HVX = (1 == opt_path) ? func_table_HVX_opt[op] : func_table_HVX[op]; - - uint8_t * restrict spad_data_th = spad_data + (ith * src0_row_size); - - const uint8_t * restrict src0_ptr = (const uint8_t *) src0->data + (src0_start_row * src0_row_size); - uint8_t * restrict dst_ptr = (uint8_t *) dst->data + (src0_start_row * dst_row_size); - - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - - const uint32_t ne02_ne01 = ne02 * ne01; - - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); - - const uint32_t i13 = fastmodulo(i03, ne13, &octx->src1_div3); - const uint32_t i12 = fastmodulo(i02, ne12, &octx->src1_div2); - const uint32_t i11 = fastmodulo(i01, ne11, &octx->src1_div1); - - const uint8_t * restrict src1_ptr = data_src1 + i13 * nb13 + i12 * nb12 + i11 * src1_row_size; - - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr, src1_row_size, src1_row_size, 1); - } - } - - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - if ((1 == is_aligned) && (nr0 == ne00)) { - hvx_splat_f32_a(spad_data_th, *(float *) src1_ptr, nr0); - } else { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data_th + r * nb11, (const uint8_t *) src1_ptr, nb11); - } - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data_th, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); - } - - src0_ptr += src0_row_size; - dst_ptr += dst_row_size; - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "binary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, - ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + return MIN(bctx->block_max, block_limit); } -static void binary_add_id_job_f32_per_thread(struct htp_ops_context * octx, - uint8_t * spad_data, - uint32_t nth, - uint32_t ith, - hvx_elemwise_f32_func func_HVX) { +// Macro for scalar op switch +#define COMPUTE_SCALAR_OP(DST, SRC, VAL, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_SUB: hvx_sub_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_MUL: hvx_mul_scalar_f32_aa(DST, SRC, VAL, N); break; \ + case HTP_OP_DIV: hvx_mul_scalar_f32_aa(DST, SRC, 1.0f / (VAL), N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Aligned) +#define COMPUTE_VECTOR_OP_AAA(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aaa(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aaa(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (Dst Aligned, Src0 Aligned, Src1 Unaligned) +#define COMPUTE_VECTOR_OP_AAU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_aau(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_aau(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// Macro for vector op switch (All Unaligned - generic loop used in element repeat) +#define COMPUTE_VECTOR_OP_UUU(DST, SRC0, SRC1, N) \ + switch (octx->op) { \ + case HTP_OP_ADD: hvx_add_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_SUB: hvx_sub_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_MUL: hvx_mul_f32_uuu(DST, SRC0, SRC1, N); break; \ + case HTP_OP_DIV: hvx_div_f32_uuu(DST, SRC0, SRC1, N); break; \ + default: break; \ + } + +// 1. Scalar src1 (ne10 == 1) +static void binary_job_scalar(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; htp_binary_preamble; - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; - // no work for this thread - if (src0_start_row >= src0_end_row) { - return; + // Preamble + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + // Main loop + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; - const uint32_t ne02_ne01 = ne02 * ne01; - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - // src0 indices - const uint32_t i03 = fastdiv(ir, &octx->src0_div21); - const uint32_t i02 = fastdiv(ir - i03 * ne02_ne01, &octx->src0_div1); - const uint32_t i01 = (ir - i03 * ne02_ne01 - i02 * ne01); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; - // src1 indices - const int i11 = *(int32_t *) ((char *) src2->data + i01 * src2->nb[0] + i02 * src2->nb[1]); - assert(i11 >= 0 && i11 < ne11); + // src1 indices (broadcast/repeat) + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(i01, ne11, &bctx->src1_dim1_div); - float * restrict dst_ptr = (float *) (data_dst + i03 * nb3 + i02 * nb2 + i01 * nb1); - const float * restrict src0_ptr = (const float *) (data_src0 + i03 * nb03 + i02 * nb02 + i01 * nb01); - const float * restrict src1_ptr = (const float *) (data_src1 + 0 + 0 + i11 * nb11); + uint8_t * src1_ptr = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint32_t s1_stride = (ne11 == 1) ? 0 : nb11; - if (ir + 1 < src0_end_row) { - hex_l2fetch(src0_ptr + ne00, src0_row_size, src0_row_size, 1); - if (src1_row_size == src0_row_size) { - hex_l2fetch(src1_ptr + ne10, src1_row_size, src1_row_size, 1); - } + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + float val = *(float *)src1_ptr; + src1_ptr += s1_stride; + COMPUTE_SCALAR_OP(r_dst, r_src0, val, ne00); } - const uint32_t nr0 = ne00 / ne10; - if (nr0 > 1) { - for (uint32_t r = 0; r < nr0; r++) { - memcpy(spad_data + r * nb10, (const uint8_t *) src1_ptr, nb10); - } - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) spad_data, ne00); - } else { - func_HVX((uint8_t *) dst_ptr, (const uint8_t *) src0_ptr, (const uint8_t *) src1_ptr, ne00); + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; } + ir += current_block_size; } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "add-id-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], - src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], - dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); + dma_queue_flush(q); } -static void binary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; +// 2. Vector Same Shape (ne1x == ne0x) or Simple Broadcast +static void binary_job_vector_same_shape(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; - switch (octx->op) { - case HTP_OP_MUL: - case HTP_OP_ADD: - case HTP_OP_SUB: - binary_job_f32_per_thread(octx, octx->src1_spad.data, n, i, octx->op); - break; + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; - case HTP_OP_ADD_ID: - binary_add_id_job_f32_per_thread(octx, octx->src0_spad.data, n, i, hvx_add_f32); - break; + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad_base = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); - default: - FARF(ERROR, "Unknown Binary Op %u", octx->op); - break; + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t src1_spad_half = octx->src1_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint32_t i13 = (ne13 == 1) ? 0 : i03; + uint32_t i12 = (ne12 == 1) ? 0 : i02; + uint32_t i11 = (ne11 == 1) ? 0 : i01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * src1_base = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * s1_spad = src1_spad_base + spad_idx * src1_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, src1_base), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + uint8_t * s1_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = s1_spad + r * bctx->src1_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + + uint32_t p13 = (ne13 == 1) ? 0 : p03; + uint32_t p12 = (ne12 == 1) ? 0 : p02; + uint32_t p11 = (ne11 == 1) ? 0 : p01; + + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + uint8_t * s1_next = (uint8_t *)src1->data + p13 * nb13 + p12 * nb12 + p11 * nb11; + + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + dma_queue_push(q, dma_make_ptr(s1_spad, s1_next), bctx->src1_row_size_aligned, bctx->src1_dma_stride, ne00 * sizeof(float), next_block_size); + + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 3. Row Broadcast (ne11 == 1, ne12 == 1, single row src1) +static void binary_job_vector_row_broadcast(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * src1_spad = octx->src1_spad.data + (ith * octx->src1_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + void * s1_ptr = (void *) src1_spad; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)s1_ptr; // Constant + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + COMPUTE_VECTOR_OP_AAA(r_dst, r_src0, r_src1, ne00); + } + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 4. Vector Complex (ne10 == ne00, complex broadcast) +static void binary_job_vector_complex(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1 = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Read src1 from DDR (unaligned) + COMPUTE_VECTOR_OP_AAU(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 5. Element Repeat (ne10 != ne00) +static void binary_job_element_repeat(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + htp_binary_preamble; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; + uint32_t i13 = fastmodulo(i03, ne13, &bctx->src1_dim3_div); + uint32_t i12 = fastmodulo(i02, ne12, &bctx->src1_dim2_div); + uint32_t i11 = fastmodulo(r_i01, ne11, &bctx->src1_dim1_div); + + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_src1_row = (uint8_t *)src1->data + i13 * nb13 + i12 * nb12 + i11 * nb11; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + // Repeat src1 row + for (uint32_t c = 0; c < ne00; c += ne10) { + uint32_t len = MIN(ne10, ne00 - c); + // Use UUU for speed and simplicity + COMPUTE_VECTOR_OP_UUU(r_dst + c * sizeof(float), r_src0 + c * sizeof(float), r_src1_row, len); + } + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); +} + +// 6. ADD_ID (src1 gathered via src2 indices) +static void binary_job_add_id(unsigned int nth, unsigned int ith, void * data) { + struct htp_binary_context * bctx = (struct htp_binary_context *) data; + struct htp_ops_context * octx = bctx->octx; + + const struct htp_tensor * src0 = &octx->src0; + const struct htp_tensor * src1 = &octx->src1; + const struct htp_tensor * src2 = &octx->src2; + struct htp_tensor * dst = &octx->dst; + + const uint32_t ne00 = src0->ne[0]; + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + const uint32_t ne03 = src0->ne[3]; + const uint32_t ne11 = src1->ne[1]; // for bounds check + + const uint32_t nb01 = src0->nb[1]; + const uint32_t nb02 = src0->nb[2]; + const uint32_t nb03 = src0->nb[3]; + const uint32_t nb11 = src1->nb[1]; // src1 row stride + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + const uint32_t total_rows = ne01 * ne02 * ne03; + const uint32_t start_row = bctx->nrows_per_thread * ith; + const uint32_t end_row = MIN(start_row + bctx->nrows_per_thread, total_rows); + if (start_row >= end_row) return; + + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + size_t src0_spad_half = octx->src0_spad.size_per_thread / 2; + size_t dst_spad_half = octx->dst_spad.size_per_thread / 2; + + dma_queue * q = octx->ctx->dma[ith]; + uint32_t ir_prefetch = start_row; + int spad_idx = 0; + + for (int k = 0; k < 2 && ir_prefetch < end_row; k++) { + uint32_t current_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir_prefetch, &bctx->dim12_div); + rem = ir_prefetch - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + uint8_t * src0_curr = (uint8_t *)src0->data + i03 * nb03 + i02 * nb02 + i01 * nb01; + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + + uint8_t * s0_spad = src0_spad_base + spad_idx * src0_spad_half; + uint8_t * d_spad = dst_spad_base + spad_idx * dst_spad_half; + + dma_queue_push_vtcm_to_ddr(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, 0); + dma_queue_push(q, dma_make_ptr(s0_spad, src0_curr), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), current_block_size); + ir_prefetch += current_block_size; + spad_idx ^= 1; + } + + for (uint32_t ir = start_row; ir < end_row; ) { + uint32_t current_block_size = calc_block_size(bctx, ir, end_row, ne01, ne02); + uint8_t * d_spad = (uint8_t *) dma_queue_pop(q).src; + uint8_t * s0_spad = (uint8_t *) dma_queue_pop(q).dst; + + uint32_t i03, i02, i01, rem; + i03 = fastdiv(ir, &bctx->dim12_div); + rem = ir - i03 * (ne02 * ne01); + i02 = fastdiv(rem, &bctx->dim1_div); + i01 = rem - i02 * ne01; + + for (uint32_t r = 0; r < current_block_size; r++) { + uint32_t r_i01 = i01 + r; // linear within block since we split at ne01 + + const int32_t idx = *(int32_t *)((char *)src2->data + r_i01 * src2->nb[0] + i02 * src2->nb[1]); + + uint8_t * r_src1 = (uint8_t *)src1->data + idx * nb11; + uint8_t * r_src0 = s0_spad + r * bctx->src0_row_size_aligned; + uint8_t * r_dst = d_spad + r * bctx->dst_row_size_aligned; + + hvx_add_f32_aau(r_dst, r_src0, r_src1, ne00); + } + + uint8_t * dst_curr = (uint8_t *)dst->data + i03 * nb3 + i02 * nb2 + i01 * nb1; + dma_queue_push(q, dma_make_ptr(dst_curr, d_spad), nb1, bctx->dst_row_size_aligned, ne00 * sizeof(float), current_block_size); + + if (ir_prefetch < end_row) { + uint32_t next_block_size = calc_block_size(bctx, ir_prefetch, end_row, ne01, ne02); + uint32_t p03, p02, p01, prem; + p03 = fastdiv(ir_prefetch, &bctx->dim12_div); + prem = ir_prefetch - p03 * (ne02 * ne01); + p02 = fastdiv(prem, &bctx->dim1_div); + p01 = prem - p02 * ne01; + uint8_t * s0_next = (uint8_t *)src0->data + p03 * nb03 + p02 * nb02 + p01 * nb01; + dma_queue_push(q, dma_make_ptr(s0_spad, s0_next), bctx->src0_row_size_aligned, nb01, ne00 * sizeof(float), next_block_size); + ir_prefetch += next_block_size; + } + ir += current_block_size; + } + dma_queue_flush(q); } static int execute_op_binary_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - worker_callback_t binary_op_func; - const char * op_type = NULL; - - switch (octx->op) { - case HTP_OP_MUL: - binary_op_func = binary_job_dispatcher_f32; - op_type = "mul-f32"; - break; - - case HTP_OP_ADD: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-f32"; - break; - - case HTP_OP_SUB: - binary_op_func = binary_job_dispatcher_f32; - op_type = "sub-f32"; - break; - - case HTP_OP_ADD_ID: - binary_op_func = binary_job_dispatcher_f32; - op_type = "add-id-f32"; - break; - - default: - FARF(ERROR, "Unsupported binary-Op %u\n", octx->op); - return HTP_STATUS_NO_SUPPORT; - } - - const int n_threads = octx->n_threads; + const uint32_t n_threads = octx->n_threads; const uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; - const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src1->nb[1]; - const size_t dst_row_size = dst->nb[1]; + // Use packed row sizes for VTCM allocation + const size_t src0_row_size = src0->ne[0] * sizeof(float); + const size_t src1_row_size = src1->ne[0] * sizeof(float); + const size_t dst_row_size = dst->ne[0] * sizeof(float); - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Align to VLEN + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + bool is_add_id = (octx->op == HTP_OP_ADD_ID); + bool is_scalar = !is_add_id && (src1->ne[0] == 1); - FARF(HIGH, - "%s: (%ux%ux%ux%u) * (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); + // Determine which kernel we will use to alloc memory and dispatch + bool use_vector_same = !is_add_id && !is_scalar && src1->ne[0] == src0->ne[0] && + (src1->ne[1] == src0->ne[1] || src1->ne[1] == 1) && + (src1->ne[2] == src0->ne[2] || src1->ne[2] == 1) && + (src1->ne[3] == src0->ne[3] || src1->ne[3] == 1); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "binary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + bool is_row_bcast = use_vector_same && (src1->ne[1] == 1 && src1->ne[2] == 1 && src1->ne[3] == 1); + bool use_complex = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] == src0->ne[0]); + bool use_repeat = !is_add_id && !is_scalar && !use_vector_same && (src1->ne[0] != src0->ne[0]); + + size_t spad_row_total; + if (is_scalar) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (is_row_bcast) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } else if (use_vector_same) { + spad_row_total = 2 * (src0_row_size_aligned + src1_row_size_aligned + dst_row_size_aligned); + } else if (is_add_id) { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); // src1 read directly + } else { + spad_row_total = 2 * (src0_row_size_aligned + dst_row_size_aligned); + } + + size_t rows_per_buffer = octx->ctx->vtcm_size / (n_threads * spad_row_total); + // Adjust for static src1 in row_bcast case + if (is_row_bcast) { + size_t needed_static = src1_row_size_aligned; + if (octx->ctx->vtcm_size < needed_static) return HTP_STATUS_VTCM_TOO_SMALL; + size_t avail = octx->ctx->vtcm_size - needed_static; + rows_per_buffer = avail / (n_threads * spad_row_total); + } + + if (rows_per_buffer < 1) { + FARF(ERROR, "binary-f32: VTCM too small\n"); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = rows_per_buffer * 2 * src0_row_size_aligned; + octx->dst_spad.size_per_thread = rows_per_buffer * 2 * dst_row_size_aligned; + + if (is_scalar || use_complex || use_repeat || is_add_id) { + octx->src1_spad.size_per_thread = 0; + } else if (is_row_bcast) { + octx->src1_spad.size_per_thread = 0; + } else { + octx->src1_spad.size_per_thread = rows_per_buffer * 2 * src1_row_size_aligned; + } + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + if (is_row_bcast) { + octx->src1_spad.size = src1_row_size_aligned; + } else { + octx->src1_spad.size = n_threads * octx->src1_spad.size_per_thread; + } + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + if (octx->ctx->vtcm_size < (octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size)) { return HTP_STATUS_VTCM_TOO_SMALL; } @@ -305,39 +757,71 @@ static int execute_op_binary_f32(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - - octx->src0_div21 = init_fastdiv_values(src0->ne[2] * src0->ne[1]); - octx->src0_div3 = init_fastdiv_values(src0->ne[3]); - octx->src0_div2 = init_fastdiv_values(src0->ne[2]); - octx->src0_div1 = init_fastdiv_values(src0->ne[1]); - - octx->src1_div21 = init_fastdiv_values(src1->ne[2] * src1->ne[1]); - octx->src1_div3 = init_fastdiv_values(src1->ne[3]); - octx->src1_div2 = init_fastdiv_values(src1->ne[2]); - octx->src1_div1 = init_fastdiv_values(src1->ne[1]); - - worker_pool_run_func(octx->ctx->worker_pool, binary_op_func, octx, n_jobs); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + dma_queue * q = octx->ctx->dma[0]; + if (is_row_bcast) { + dma_queue_push(q, dma_make_ptr(octx->src1_spad.data, (const void *) src1->data), src1_row_size_aligned, 0, src1->ne[0] * sizeof(float), 1); + } + + struct htp_binary_context bctx; + bctx.octx = octx; + bctx.nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + bctx.block_max = rows_per_buffer; + bctx.src0_row_size_aligned = src0_row_size_aligned; + bctx.src1_row_size_aligned = src1_row_size_aligned; + bctx.dst_row_size_aligned = dst_row_size_aligned; + + bctx.dim1_div = init_fastdiv_values(src0->ne[1]); + bctx.dim2_div = init_fastdiv_values(src0->ne[2]); + bctx.dim12_div = init_fastdiv_values(src0->ne[1] * src0->ne[2]); + + bctx.src1_dim1_div = init_fastdiv_values(src1->ne[1]); + bctx.src1_dim2_div = init_fastdiv_values(src1->ne[2]); + bctx.src1_dim3_div = init_fastdiv_values(src1->ne[3]); + + bool src0_contig_dim1 = (src0->nb[2] == src0->ne[1] * src0->nb[1]); + bool dst_contig_dim1 = (dst->nb[2] == src0->ne[1] * dst->nb[1]); + + bool src0_contig_dim2 = (src0->nb[3] == src0->ne[2] * src0->nb[2]); + bool dst_contig_dim2 = (dst->nb[3] == src0->ne[2] * dst->nb[2]); + + bctx.split_at_ne01 = (src0->ne[2] > 1) && + ((src1->ne[1] > 1) || (src1->ne[2] > 1) || !src0_contig_dim1 || !dst_contig_dim1); + + bctx.split_at_ne02 = (src0->ne[3] > 1) && + ((src1->ne[2] > 1) || (src1->ne[3] > 1) || !src0_contig_dim2 || !dst_contig_dim2); + + // Precompute specific kernel parameters + if (use_vector_same) { + bctx.src1_dma_stride = (src1->ne[1] == 1) ? 0 : src1->nb[1]; + bctx.src1_fetch_rows = (src1->ne[1] == 1) ? 1 : rows_per_buffer; + } + + worker_callback_t worker_func; + if (is_add_id) worker_func = binary_job_add_id; + else if (is_scalar) worker_func = binary_job_scalar; + else if (is_row_bcast) worker_func = binary_job_vector_row_broadcast; + else if (use_vector_same) worker_func = binary_job_vector_same_shape; + else if (use_complex) worker_func = binary_job_vector_complex; + else worker_func = binary_job_element_repeat; + + if (is_row_bcast) { + dma_queue_pop(q); + } + + worker_pool_run_func(octx->ctx->worker_pool, worker_func, &bctx, n_jobs); + + return HTP_STATUS_OK; } int op_binary(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - - switch (octx->src0.type) { - case HTP_TYPE_F32: - err = execute_op_binary_f32(octx); - break; - - default: - err = HTP_STATUS_NO_SUPPORT; - break; + if (octx->src0.type == HTP_TYPE_F32) { + return execute_op_binary_f32(octx); } - - return err; + return HTP_STATUS_NO_SUPPORT; } diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c184637443..74c777d4c3 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,121 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" -static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x0_hf = Q6_V_vand_QV(bmask, x0_hf); - x1_hf = Q6_V_vand_QV(bmask, x1_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); -} - // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); @@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); @@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } -// MAD: y (F32) += x (F16) * s (float) +// MAD: y (F32) += x (F16) * s (F32) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict } } +// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, + const void * restrict x0, + const void * restrict x1, + float s0, + float s1, + int n) { + const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S0 = hvx_vec_splat_f16(s0); + HVX_Vector S1 = hvx_vec_splat_f16(s1); + + uint32_t i = 0; + #pragma unroll(2) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + + ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); + ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + } + + if (nloe) { + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs = xs_p_lo; + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; + xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + } + } +} + #define FLASH_ATTN_BLOCK_SIZE 128 -static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { +struct htp_fa_context { + const struct htp_ops_context * octx; + + struct fastdiv_values src0_div21; + struct fastdiv_values src0_div1; + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values src3_div2; + struct fastdiv_values src3_div3; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t n_blocks; + + size_t size_q_row_padded; + size_t size_k_row_padded; + size_t size_v_row_padded; + + size_t size_k_block; + size_t size_v_block; + size_t size_m_block; + + bool is_q_fp32; +}; + +static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + + const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src; + HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst; + + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + uint32_t i = 0; + #pragma unroll(4) + for (; i < nvec; ++i) { + vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + } + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + } +} + +static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_fa_context * factx = (struct htp_fa_context *) data; + const struct htp_ops_context * octx = factx->octx; const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = &octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - // total rows in q const uint32_t nr = neq1*neq2*neq3; @@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t DV = nev0; const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); - const size_t size_q_row_padded = hex_round_up(size_q_row, 128); - const size_t size_k_row = DK * sizeof(__fp16); const size_t size_v_row = DV * sizeof(__fp16); - const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask - - const size_t size_k_row_padded = hex_round_up(size_k_row, 128); - const size_t size_v_row_padded = hex_round_up(size_v_row, 128); - - const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; @@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); for (uint32_t ir = ir0; ir < ir1; ++ir) { - const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); - const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); - const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); - const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2); - const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); - const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2); // Fetch Q row const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); - dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); // Clear accumulator hvx_splat_f32_a(spad_a, 0, DV); @@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const __fp16 * mp_base = NULL; if (mask) { - const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); - const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3); mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); } - const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; - // Prefetch first two blocks - for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; - dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; - dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size); // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } } - const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + if (factx->is_q_fp32) { + hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 + } - for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const HVX_Vector slope_vec = hvx_vec_splat_f16(slope); + for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); - // Process in blocks of 32 (VLEN_FP32) static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; @@ -437,22 +403,18 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (int j = 0; j < VLEN_FP32; j += 2) { + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (is_q_fp32) { - hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } else { - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } + const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); } HVX_Vector scores = *(HVX_Vector *) scores_arr; // 2. Softcap - if (logit_softcap != 0.0f) { + if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); scores = Q6_Vsf_equals_Vqf32(scores); } @@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - - HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); - - HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - - HVX_Vector slope_vec = hvx_vec_splat_f32(slope); - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); - scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); scores = Q6_Vsf_equals_Vqf32(scores); } scores_x4.v[iv] = scores; - v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update - v_max = hvx_vec_reduce_max_f32(v_max); - float m_block = hvx_vec_get_f32(v_max); - float M_old = M; - float M_new = (m_block > M) ? m_block : M; - M = M_new; + HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); + HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); + HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + M_vec = M_new_vec; - const float ms = expf(M_old - M_new); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = scores_x4.v[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. Accumulate V float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; + *(HVX_Vector *) p_arr = P; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S = S * ms + hvx_vec_get_f32(p_sum_vec); + S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + // Leftover for (; ic < current_block_size; ++ic) { float s_val; - const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - - if (is_q_fp32) { - hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } - - if (logit_softcap != 0.0f) { - s_val = logit_softcap * tanhf(s_val); + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); } if (mask) { @@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const float Mold = M; - float ms = 1.0f; float vs = 1.0f; if (s_val > M) { M = s_val; - ms = expf(Mold - M); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; } else { - vs = expf(s_val - M); + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; } - const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); - - S = S * ms + vs; } + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) - if (ib + 2 < n_blocks) { + if (ib + 2 < factx->n_blocks) { const uint32_t next_ib = ib + 2; const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size); // Mask if (mask) { @@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } // sinks + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + if (sinks) { const float s = ((float *)((char *) sinks->data))[h]; - float ms = 1.0f; float vs = 1.0f; if (s > M) { - ms = expf(M - s); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - } else { - vs = expf(s - M); - } + HVX_Vector diff_vec = hvx_vec_splat_f32(M - s); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - S = S * ms + vs; + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; + } else { + HVX_Vector diff_vec = hvx_vec_splat_f32(s - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; + } } const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; @@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } } -static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - flash_attn_ext_f16_thread(octx, i, n); -} - int op_flash_attn_ext(struct htp_ops_context * octx) { const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * dst = &octx->dst; // Check support - if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || - k->type != HTP_TYPE_F16 || - v->type != HTP_TYPE_F16) { + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); - octx->src0_div1 = init_fastdiv_values(q->ne[1]); + struct htp_fa_context factx; + factx.octx = octx; - octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); - octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); - octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); - octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + factx.src0_div1 = init_fastdiv_values(q->ne[1]); + + factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); if (mask) { - octx->src3_div2 = init_fastdiv_values(mask->ne[2]); - octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + factx.src3_div2 = init_fastdiv_values(mask->ne[2]); + factx.src3_div3 = init_fastdiv_values(mask->ne[3]); } - size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); - size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); - size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 4 : 2), 128); + factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); - size_t size_q_block = size_q_row_padded * 1; // single row for now - size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + size_t size_q_block = factx.size_q_row_padded * 1; // single row for now + factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + uint32_t n_head = q->ne[2]; + factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; - octx->src1_spad.size_per_thread = size_k_block * 2; - octx->src2_spad.size_per_thread = size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->src1_spad.size_per_thread = factx.size_k_block * 2; + octx->src2_spad.size_per_thread = factx.size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? factx.size_m_block * 2 : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index f49e8ee447..25403bb112 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -42,32 +42,36 @@ enum htp_data_type { HTP_TYPE_COUNT }; -// These values are manually translated over to HTP -// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!! +// Do not reorder first 4 (used as an index) enum htp_op { - HTP_OP_MUL = 0, - HTP_OP_ADD = 1, - HTP_OP_SUB = 2, - HTP_OP_DIV = 3, - HTP_OP_MUL_MAT = 4, - HTP_OP_MUL_MAT_ID = 5, - HTP_OP_RMS_NORM = 6, - HTP_OP_UNARY_SILU = 7, - HTP_OP_UNARY_GELU = 8, - HTP_OP_GLU_SWIGLU = 9, - HTP_OP_GLU_SWIGLU_OAI = 10, - HTP_OP_SOFTMAX = 11, - HTP_OP_ADD_ID = 12, - HTP_OP_ROPE = 13, - HTP_OP_FLASH_ATTN_EXT = 14, - HTP_OP_SET_ROWS = 15, - HTP_OP_SCALE = 16, - HTP_OP_GET_ROWS = 17, - HTP_OP_CPY = 18, + HTP_OP_MUL = 0, + HTP_OP_ADD = 1, + HTP_OP_SUB = 2, + HTP_OP_DIV = 3, + HTP_OP_MUL_MAT, + HTP_OP_MUL_MAT_ID, + HTP_OP_RMS_NORM, + HTP_OP_UNARY_SILU, + HTP_OP_UNARY_GELU, + HTP_OP_GLU_SWIGLU, + HTP_OP_GLU_SWIGLU_OAI, + HTP_OP_GLU_GEGLU, + HTP_OP_SOFTMAX, + HTP_OP_ADD_ID, + HTP_OP_ROPE, + HTP_OP_FLASH_ATTN_EXT, + HTP_OP_SET_ROWS, + HTP_OP_GET_ROWS, + HTP_OP_SCALE, + HTP_OP_CPY, + HTP_OP_ARGSORT, + HTP_OP_SQR, + HTP_OP_SQRT, + HTP_OP_SUM_ROWS, INVALID }; -static inline size_t htp_type_block_size(uint32_t t) { +static inline size_t htp_t_block_size(uint32_t t) { switch (t) { case HTP_TYPE_F32: return 1; @@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) { return 0; } -static const char * htp_type_name(uint32_t t) { - switch (t) { - case HTP_TYPE_F32: - return "fp32"; - case HTP_TYPE_F16: - return "fp16"; - case HTP_TYPE_Q4_0: - return "q4_0"; - case HTP_TYPE_Q8_0: - return "q8_0"; - case HTP_TYPE_MXFP4: - return "mxfp4"; - } - return 0; -} - // Internal types #define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128) #define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index 602a2775a4..f1ad24dbfa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -64,25 +64,12 @@ struct htp_ops_context { struct fastdiv_values broadcast_rv2; struct fastdiv_values broadcast_rv3; - struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 - struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 - struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 - struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01 - struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02 - struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03 - - struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00 - struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01 - struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02 - uint32_t flags; }; @@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx); int op_matmul_id(struct htp_ops_context * octx); int op_binary(struct htp_ops_context * octx); int op_unary(struct htp_ops_context * octx); +int op_sum_rows(struct htp_ops_context * octx); int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); @@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx); int op_set_rows(struct htp_ops_context * octx); int op_get_rows(struct htp_ops_context * octx); int op_cpy(struct htp_ops_context * octx); +int op_argsort(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-arith.h b/ggml/src/ggml-hexagon/htp/hvx-arith.h index 3449739a4f..2577cdd041 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-arith.h +++ b/ggml/src/ggml-hexagon/htp/hvx-arith.h @@ -46,127 +46,76 @@ #define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) #endif -// ADD variants +// Generic macro to define alignment permutations for an op +#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \ +static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) dst % 128 == 0); \ + hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \ +} \ +static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src0 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + assert((uintptr_t) src1 % 128 == 0); \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \ +} \ +static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \ + hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \ +} \ -static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD); +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB) +DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL) + +// Dispatcher logic +#define HVX_BINARY_DISPATCHER(OP_NAME) \ +static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \ + if (hex_is_aligned((void *) dst, 128)) { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \ + else OP_NAME##_aau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \ + else OP_NAME##_auu(dst, src0, src1, num_elems); \ + } \ + } else { \ + if (hex_is_aligned((void *) src0, 128)) { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \ + else OP_NAME##_uau(dst, src0, src1, num_elems); \ + } else { \ + if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \ + else OP_NAME##_uuu(dst, src0, src1, num_elems); \ + } \ + } \ } -static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD); -} - -static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD); -} - -static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD); -} - -// SUB variants - -static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB); -} - -static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB); -} - -// MUL variants - -static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) dst % 128 == 0); - assert((unsigned long) src0 % 128 == 0); - hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - assert((unsigned long) src0 % 128 == 0); - assert((unsigned long) src1 % 128 == 0); - hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL); -} - -static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { - hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL); -} - -// Dispatchers - -static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_add_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_add_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_add_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_sub_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_sub_f32_uu(dst, src0, src1, num_elems); - } -} - -static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { - if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) { - if (hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_aa(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_au(dst, src0, src1, num_elems); - } - } else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) { - hvx_mul_f32_ua(dst, src0, src1, num_elems); - } else { - hvx_mul_f32_uu(dst, src0, src1, num_elems); - } -} +HVX_BINARY_DISPATCHER(hvx_add_f32) +HVX_BINARY_DISPATCHER(hvx_sub_f32) +HVX_BINARY_DISPATCHER(hvx_mul_f32) // Mul-Mul Optimized - static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src0 % 128 == 0); @@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * } } +// +// Square +// + +#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t elem_size = sizeof(float); \ + const uint32_t epv = 128 / elem_size; \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * elem_size, v); \ + } \ + } while(0) + +static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_aa(dst, src, num_elems); + } else { + hvx_sqr_f32_au(dst, src, num_elems); + } + } else { + if (hex_is_aligned((void *) src, 128)) { + hvx_sqr_f32_ua(dst, src, num_elems); + } else { + hvx_sqr_f32_uu(dst, src, num_elems); + } + } +} + #undef HVX_OP_ADD #undef HVX_OP_SUB #undef HVX_OP_MUL @@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t * #undef hvx_scalar_loop_body #undef HVX_OP_MIN_SCALAR #undef HVX_OP_CLAMP_SCALAR +#undef DEFINE_HVX_BINARY_OP_VARIANTS +#undef HVX_BINARY_DISPATCHER #endif // HVX_ARITH_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-base.h b/ggml/src/ggml-hexagon/htp/hvx-base.h index ffa6e18e64..12a1b7f128 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-base.h +++ b/ggml/src/ggml-hexagon/htp/hvx-base.h @@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) { return x; } +static inline int32_t hvx_vec_get_i32(HVX_Vector v) { + int32_t __attribute__((aligned(128))) x; + hvx_vec_store_a(&x, 4, v); + return x; +} + static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) { // abs by clearing the fp16 sign bit HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff); diff --git a/ggml/src/ggml-hexagon/htp/hvx-copy.h b/ggml/src/ggml-hexagon/htp/hvx-copy.h index 6b617b7617..ae0dbed030 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-copy.h +++ b/ggml/src/ggml-hexagon/htp/hvx-copy.h @@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr dst_type * restrict vdst = (dst_type *) dst; \ src_type * restrict vsrc = (src_type *) src; \ \ - const HVX_Vector zero = Q6_V_vsplat_R(0); \ - \ const uint32_t elem_size = sizeof(__fp16); \ const uint32_t epv = 128 / elem_size; \ const uint32_t nvec = n / epv; \ diff --git a/ggml/src/ggml-hexagon/htp/hvx-div.h b/ggml/src/ggml-hexagon/htp/hvx-div.h new file mode 100644 index 0000000000..7dae012e0e --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/hvx-div.h @@ -0,0 +1,116 @@ +#ifndef HVX_DIV_H +#define HVX_DIV_H + +#include + +#include +#include +#include +#include +#include + +#include "hvx-base.h" +#include "hex-utils.h" +#include "hvx-inverse.h" +#include "hvx-arith.h" + +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + +#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src0_type * restrict vsrc0 = (src0_type *) src0; \ + src1_type * restrict vsrc1 = (src1_type *) src1; \ + \ + const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vdst[i] = res; \ + } \ + if (nloe) { \ + HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \ + HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \ + } \ + } while(0) + +// 3-letter suffix variants +static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) dst % 128 == 0); + hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src0 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + assert((uintptr_t) src1 % 128 == 0); + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { + hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { + if (hex_is_aligned((void *) dst, 128)) { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems); + else hvx_div_f32_aau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems); + else hvx_div_f32_auu(dst, src0, src1, num_elems); + } + } else { + if (hex_is_aligned((void *) src0, 128)) { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems); + else hvx_div_f32_uau(dst, src0, src1, num_elems); + } else { + if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems); + else hvx_div_f32_uuu(dst, src0, src1, num_elems); + } + } +} + +#undef HVX_OP_MUL + +#endif // HVX_DIV_H diff --git a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h index 1b4aaff0c9..095193277e 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sigmoid.h @@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) { } \ } while(0) +#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t epv = 128 / sizeof(float); \ + const uint32_t nvec = n / epv; \ + const uint32_t nloe = n % epv; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \ + } \ + if (nloe) { \ + HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \ + } \ + } while(0) + static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { assert((unsigned long) dst % 128 == 0); assert((unsigned long) src % 128 == 0); @@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); } +static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + #endif /* HVX_SIGMOID_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h index 28ee9f68d3..e31a1006d2 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-sqrt.h +++ b/ggml/src/ggml-hexagon/htp/hvx-sqrt.h @@ -12,11 +12,17 @@ #define RSQRT_ONE_HALF 0x3f000000 // 0.5 #define RSQRT_THREE_HALVES 0x3fc00000 // 1.5 +#if __HVX_ARCH__ < 79 +#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b)) +#else +#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b) +#endif + static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { //Algorithm : // x2 = input*0.5 // y = * (long *) &input - // y = 0x5f3759df - (y>>2) + // y = 0x5f3759df - (y>>1) // y = y*(threehalfs - x2*y*y) HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST); @@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) { return Q6_Vsf_equals_Vqf32(temp); } +// Compute sqrt(x) as x*inv_sqrt(x) +#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \ + do { \ + dst_type * restrict vdst = (dst_type *) dst; \ + src_type * restrict vsrc = (src_type *) src; \ + \ + const uint32_t nvec = n / VLEN_FP32; \ + const uint32_t nloe = n % VLEN_FP32; \ + \ + uint32_t i = 0; \ + \ + _Pragma("unroll(4)") \ + for (; i < nvec; i++) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vdst[i] = sqrt_res; \ + } \ + if (nloe) { \ + HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \ + HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \ + vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \ + } \ + } while(0) + +static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) dst % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a); +} + +static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + assert((unsigned long) src % 128 == 0); + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u); +} + +static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) { + if ((unsigned long) dst % 128 == 0) { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_aa(dst, src, num_elems); + } else { + hvx_sqrt_f32_au(dst, src, num_elems); + } + } else { + if ((unsigned long) src % 128 == 0) { + hvx_sqrt_f32_ua(dst, src, num_elems); + } else { + hvx_sqrt_f32_uu(dst, src, num_elems); + } + } +} + #endif /* HVX_SQRT_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index 7b79a5ea32..a518ad3733 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -12,6 +12,7 @@ #include "hvx-sigmoid.h" #include "hvx-sqrt.h" #include "hvx-arith.h" +#include "hvx-div.h" #include "hvx-base.h" #endif /* HVX_UTILS_H */ diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index e28a67a95d..92a1422896 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) { // otherwise we'll release it once we're done with the current Op. if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = false; + ctx->vtcm_needs_release = true; return 0; } @@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_argsort(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { struct dspqueue_buffer rsp_bufs[1]; @@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[1].fd; + rsp_bufs[0].ptr = bufs[1].ptr; + rsp_bufs[0].offset = bufs[1].offset; + rsp_bufs[0].size = bufs[1].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.dst.data = (uint32_t) bufs[1].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_sum_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_activations_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_MUL: case HTP_OP_ADD: case HTP_OP_SUB: + case HTP_OP_DIV: if (n_bufs != 3) { FARF(ERROR, "Bad binary-req buffer list"); continue; @@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_unary_req(ctx, &req, bufs); break; + case HTP_OP_SQR: + case HTP_OP_SQRT: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_unary_req(ctx, &req, bufs); + break; + + case HTP_OP_SUM_ROWS: + if (n_bufs != 2) { + FARF(ERROR, "Bad unary-req buffer list"); + continue; + } + + proc_sum_rows_req(ctx, &req, bufs); + break; + case HTP_OP_UNARY_SILU: case HTP_OP_UNARY_GELU: if (n_bufs != 2) { @@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { case HTP_OP_GLU_SWIGLU: case HTP_OP_GLU_SWIGLU_OAI: case HTP_OP_SOFTMAX: + case HTP_OP_GLU_GEGLU: if ((n_bufs != 2) && (n_bufs != 3)) { FARF(ERROR, "Bad act-req buffer list"); continue; @@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_cpy_req(ctx, &req, bufs); break; + case HTP_OP_ARGSORT: + if (n_bufs != 2) { + FARF(ERROR, "Bad argsort-req buffer list"); + continue; + } + proc_argsort_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index d251eeed33..c360abe8da 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -23,10 +23,30 @@ #define MM_SPAD_SRC1_NROWS 16 #define MM_SPAD_DST_NROWS 2 -struct htp_matmul_type { +struct htp_matmul_context { const char * type; - void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; }; // vdelta control to replicate first 4x fp32 values across lanes @@ -122,6 +142,7 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -133,15 +154,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 // Convert uint4 to int4 (i.e. x - 8) - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -156,6 +176,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -166,15 +187,14 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -196,46 +216,6 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 64 vals - HVX_Vector v1 = vptr[1]; // second 64 vals - HVX_Vector v2 = vptr[2]; // third 64 vals - HVX_Vector v3 = vptr[3]; // forth 64 vals - - HVX_Vector_x4 r = { v0, v1, v2, v3 }; - return r; -} - -static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { - const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; - - HVX_VectorPair v0 = vptr[0]; // first 64 vals - HVX_VectorPair v1 = vptr[1]; // second 64 vals - HVX_VectorPair v2 = vptr[2]; // third 64 vals - HVX_VectorPair v3 = vptr[3]; // forth 64 vals - - HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); - HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); - HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); - HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); - HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); - HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); - HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); - HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); - - HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); - HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); - HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); - HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); - - // vcombine does a shuffle, use vdeal to undo - - HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; - return r; -} - // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. @@ -300,26 +280,26 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, return hvx_vec_rmpy_x8_n(x, y, 1024); } -static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -372,36 +352,34 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -468,13 +446,143 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; @@ -486,11 +594,11 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -543,36 +651,34 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -639,16 +745,143 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_mxfp4x4x2_q8x4x2(const int n, - float * restrict s, - const void * restrict vx, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q8_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -660,11 +893,11 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -747,36 +980,34 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -879,10 +1110,180 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -912,14 +1313,12 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; - const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); - const HVX_Vector * restrict y = (const HVX_Vector *) vy; +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; @@ -953,10 +1352,86 @@ static void vec_dot_f16_f16_aa_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -986,7 +1461,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -1083,14 +1558,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; uint64_t t1, t2; @@ -1136,13 +1613,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); - const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1); const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); // broadcast src0 into src1 - const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); - const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &mmctx->mm_div_r2); const uint32_t i1 = i11; const uint32_t i2 = i12; @@ -1155,7 +1632,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } } } @@ -1170,7 +1647,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx } // src1 tensor is already in VTCM spad -static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1195,7 +1672,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; uint8_t * restrict src1_data = src1_spad->data; @@ -1219,11 +1696,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - #pragma unroll(2) - for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1247,20 +1734,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1311,7 +1798,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); // Prefetch next (n + spad_nrows) row const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1329,14 +1816,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1350,7 +1837,7 @@ struct mmid_row_mapping { }; // src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1423,11 +1910,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1453,25 +1939,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1531,7 +2016,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1549,13 +2034,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1754,12 +2239,14 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui hvx_copy_f16_ua(y_d, t_d, nb * 8); } -static void quantize_f32_q8x4x2(const struct htp_tensor * src, - uint8_t * restrict dst, - struct htp_spad * spad, - uint32_t nth, - uint32_t ith, - uint32_t nrows_per_thread) { +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1799,8 +2286,14 @@ static void quantize_f32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1835,8 +2328,14 @@ static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict d } // TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1870,213 +2369,76 @@ static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict d ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); -} - -static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -// ** matmul/matvec callbacks for worker_pool - -static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f32"; - mt.vec_dot = vec_dot_f16_f32_uu; - - matmul_4d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_uu; - - matmul_4d(&mt, octx, n, i); -} - -// ** matmul-id callbacks for worker_pool - -static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -// ** main matmul entry point static inline bool htp_is_permuted(const struct htp_tensor * t) { return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; } +static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + return 0; + default: + return -1; + } +} + +static void htp_mminit_spad(struct htp_ops_context * octx, + size_t dst_row_size, + size_t src0_row_size_padded, + size_t src1_row_size, + uint32_t src1_nrows, + size_t src2_spad_size_per_thread) { + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + if (src2_spad_size_per_thread > 0) { + octx->src2_spad.size_per_thread = src2_spad_size_per_thread; + octx->src2_spad.size = octx->src2_spad.size_per_thread; + } + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; +} + int op_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; size_t src1_row_size = nb11; @@ -2085,181 +2447,95 @@ int op_matmul(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func; + worker_callback_t matmul_job_func = src1_nrows > 1 ? matmul_2d : matvec_2d; bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; - } + if (src0->type == HTP_TYPE_F16) { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + + src1_row_size = f16_src1_row_size; // row size post quantization octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; } else { - matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; } - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + src1_row_size = nb11; // original row size in DDR octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - case HTP_TYPE_MXFP4: - op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; - } + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_F16: - { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_f32_f16 : htp_quantize_f16_f16; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_f16_f16; - } else { - matmul_job_func = htp_matvec_2d_f16_f16; - } - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - op_type = "f16-f32"; - matmul_job_func = htp_matmul_4d_f16_f32; - } else { - op_type = "f16-f16"; - matmul_job_func = htp_matmul_4d_f16_f16; - } - - src1_row_size = nb11; // original row size in DDR - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } - break; - - default: + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } // VTCM scratchpads for all tensors size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2268,40 +2544,32 @@ int op_matmul(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even - octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; if (need_quant) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; } -// ** main matmul-id entry point - int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + struct htp_tensor * restrict ids = &octx->src2; - const char * op_type; - - worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func; - const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -2310,6 +2578,13 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + size_t src1_row_size; size_t src1_row_size_padded; @@ -2320,112 +2595,29 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - default: - return HTP_STATUS_NO_SUPPORT; + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2434,8 +2626,8 @@ int op_matmul_id(struct htp_ops_context * octx) { octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; if (src1_nrows > 1) { // initialize matrix_row_counts and map @@ -2447,8 +2639,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; ++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = - *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); assert(i02 >= 0 && i02 < n_as); @@ -2460,16 +2651,14 @@ int op_matmul_id(struct htp_ops_context * octx) { // Setup worker pool callbacks if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { - // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul-id jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c new file mode 100644 index 0000000000..62e45da2b3 --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -0,0 +1,115 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#include +#include + +#include +#include + +#include "hex-dma.h" +#include "hvx-utils.h" + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" + + +#define sum_rows_preamble \ + struct htp_tensor *src0 = &octx->src0;\ + struct htp_tensor *dst = &octx->dst; \ + \ + const uint32_t ne00 = src0->ne[0]; \ + const uint32_t ne01 = src0->ne[1]; \ + const uint32_t ne02 = src0->ne[2]; \ + const uint32_t ne03 = src0->ne[3]; \ + \ + const uint32_t nb00 = src0->nb[0]; \ + const uint32_t nb01 = src0->nb[1]; \ + const uint32_t nb02 = src0->nb[2]; \ + const uint32_t nb03 = src0->nb[3]; \ + \ + const uint32_t ne0 = dst->ne[0]; \ + const uint32_t ne1 = dst->ne[1]; \ + const uint32_t ne2 = dst->ne[2]; \ + const uint32_t ne3 = dst->ne[3]; \ + \ + const uint32_t nb0 = dst->nb[0]; \ + const uint32_t nb1 = dst->nb[1]; \ + const uint32_t nb2 = dst->nb[2]; \ + const uint32_t nb3 = dst->nb[3]; \ + +static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { + sum_rows_preamble; + + const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const size_t src0_row_size = nb01; + const size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return HTP_STATUS_OK; + } + + int opt_path = 0; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = 1; + } + + const uint8_t * restrict data_src = (const uint8_t *) src0->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); + float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + + for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { + const float * restrict src_local = src_th + (ir * ne00); + + if (ir + 1 < src0_nrows_per_thread) { + hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + } + + if (1 == opt_path) { + dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); + } else { + dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); + } + } + + return HTP_STATUS_OK; +} + +static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { + sum_rows_thread_f32((struct htp_ops_context *) data, n, i); +} + +int op_sum_rows(struct htp_ops_context * octx) { + sum_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + const int n_threads = octx->n_threads; + const uint32_t src0_nrows = ne01 * ne02 * ne03; + + uint32_t n_jobs = MIN(n_threads, src0_nrows); + octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + + return HTP_STATUS_OK; +} + diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index 1a27cb6e63..ce879bf037 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src, } } +static void sqr_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + +static void sqrt_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + hex_l2fetch(src_local + row_elems, row_size, row_size, 1); + } + + if (1 == opt_path) { + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } else { + hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); + } + } +} + static void unary_job_f32_per_thread(const struct htp_tensor * src, struct htp_tensor * dst, uint8_t * spad, @@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_SCALE: scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SQR: + sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; + case HTP_OP_SQRT: + sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "scale-f32"; break; + case HTP_OP_SQR: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqr-f32"; + break; + case HTP_OP_SQRT: + unary_op_func = unary_job_dispatcher_f32; + op_type = "sqrt-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h index baadfe9a7b..e3714b38a6 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h @@ -98,6 +98,10 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline bool ggml_impl_is_view(const struct ggml_tensor * t) { + return t->view_src != NULL; +} + static inline float ggml_compute_softplus_f32(float input) { return (input > 20.0f) ? input : logf(1 + expf(input)); } diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 95627d3866..2eb9820bff 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -264,15 +264,26 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vectorsrc[0])); - char base[256]; char name[256]; - const int64_t n = ggml_nelements(op); + int op_num = -1; - const char * op_str = "undefined"; switch (op->op) { - case GGML_OP_SCALE: op_str = "scale"; break; - case GGML_OP_FILL: op_str = "fill"; break; - case GGML_OP_CLAMP: op_str = "clamp"; break; - case GGML_OP_SQR: op_str = "sqr"; break; - case GGML_OP_SQRT: op_str = "sqrt"; break; - case GGML_OP_SIN: op_str = "sin"; break; - case GGML_OP_COS: op_str = "cos"; break; - case GGML_OP_LOG: op_str = "log"; break; - case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break; + case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break; + case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break; + case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break; + case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break; + case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break; + case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break; + case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break; + case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break; + case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break; case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { - case GGML_UNARY_OP_TANH: op_str = "tanh"; break; - case GGML_UNARY_OP_RELU: op_str = "relu"; break; - case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break; - case GGML_UNARY_OP_GELU: op_str = "gelu"; break; - case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break; - case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break; - case GGML_UNARY_OP_SILU: op_str = "silu"; break; - case GGML_UNARY_OP_ELU: op_str = "elu"; break; - case GGML_UNARY_OP_NEG: op_str = "neg"; break; - case GGML_UNARY_OP_ABS: op_str = "abs"; break; - case GGML_UNARY_OP_SGN: op_str = "sgn"; break; - case GGML_UNARY_OP_STEP: op_str = "step"; break; - case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break; - case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break; - case GGML_UNARY_OP_EXP: op_str = "exp"; break; - case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break; - case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break; + case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break; + case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break; + case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break; + case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break; + case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break; + case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break; + case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break; + case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break; + case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break; + case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break; + case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break; + case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break; + case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break; + case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break; + case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break; + case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break; + case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break; default: GGML_ABORT("fatal error"); } break; default: GGML_ABORT("fatal error"); }; - const char * suffix = ""; - if (n % 4 == 0) { - suffix = "_4"; - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768; + + snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0); + ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } + res.c4 = is_c4; + res.cnt = is_cnt; + return res; } @@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l } ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { - GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } @@ -1392,34 +1415,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v GGML_UNUSED(op); } -ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( - ggml_metal_library_t lib, - ggml_op op, - int32_t n_fuse, - bool row) { +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) { char base[256]; char name[256]; - const char * op_str = "undefined"; - switch (op) { - case GGML_OP_ADD: op_str = "add"; break; - case GGML_OP_SUB: op_str = "sub"; break; - case GGML_OP_MUL: op_str = "mul"; break; - case GGML_OP_DIV: op_str = "div"; break; + int op_num = -1; + + switch (op->op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; default: GGML_ABORT("fatal error"); }; - if (row) { - snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse); - } else { - snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse); - } + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t1_str = ggml_type_name(op->src[1]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0); + + const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : ""); + snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + } + + res.c4 = is_c4; + res.cnt = is_rb; + + return res; +} + +ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) { + char base[256]; + char name[256]; + + int op_num = -1; + + switch (op) { + case GGML_OP_ADD: op_num = 0; break; + case GGML_OP_SUB: op_num = 1; break; + case GGML_OP_MUL: op_num = 2; break; + case GGML_OP_DIV: op_num = 3; break; + default: GGML_ABORT("fatal error"); + }; + + snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32"); + snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1); + + ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); + if (!res.pipeline) { + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0); + ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1); + ggml_metal_cv_set_bool (cv, false, FC_BIN + 2); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } return res; @@ -1428,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin( ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) { assert(op->op == GGML_OP_L2_NORM); - GGML_ASSERT(op->src[0]->ne[0] % 4 == 0); - GGML_ASSERT(ggml_is_contiguous_1(op->src[0])); - char base[256]; char name[256]; - snprintf(base, 256, "kernel_l2_norm_f32"); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); + + snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : ""); snprintf(name, 256, "%s", base); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); @@ -1442,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); } + res.c4 = is_c4; res.smem = 32*sizeof(float); return res; diff --git a/ggml/src/ggml-metal/ggml-metal-device.h b/ggml/src/ggml-metal/ggml-metal-device.h index 84dcec3083..93d7f6a216 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ggml/src/ggml-metal/ggml-metal-device.h @@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params { int nr1; size_t smem; + + bool c4; + bool cnt; }; int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline); @@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op); -struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse ); +struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse); diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c8e737d418..3db7f12629 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name); @@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) { struct ggml_metal_pipeline_with_params res = { /*.pipeline =*/ nil, + /*.nsg =*/ 0, /*.nr0 =*/ 0, /*.nr1 =*/ 0, - /*.nsg =*/ 0, /*.smem =*/ 0, + /*.c4 =*/ false, + /*.cnt =*/ false, }; [lib->lock lock]; @@ -1007,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te } switch (op->op) { + case GGML_OP_SCALE: + case GGML_OP_FILL: + case GGML_OP_CLAMP: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_LOG: + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1026,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -1054,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_REPEAT: - case GGML_OP_SCALE: - case GGML_OP_FILL: case GGML_OP_CONV_TRANSPOSE_1D: return true; case GGML_OP_CONV_TRANSPOSE_2D: @@ -1066,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; - case GGML_OP_CLAMP: - return op->src[0]->type == GGML_TYPE_F32; - case GGML_OP_SQR: - case GGML_OP_SQRT: - case GGML_OP_SIN: - case GGML_OP_COS: - case GGML_OP_LOG: - return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_SUM: return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]); case GGML_OP_TRI: @@ -1083,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: case GGML_OP_GROUP_NORM: - return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_L2_NORM: - return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); + return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_COUNT_EQUAL: return has_simdgroup_reduction && op->src[0]->type == GGML_TYPE_I32 && @@ -1157,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; + case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 7f73cb97bb..383e0d6e93 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -80,6 +80,9 @@ #define FC_SSM_CONV 900 #define FC_SOLVE_TRI 1000 #define FC_COUNT_EQUAL 1100 +#define FC_UNARY 1200 +#define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -88,6 +91,37 @@ #define OP_FLASH_ATTN_EXT_VEC_NQPSG 1 #define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 +#define OP_UNARY_NUM_SCALE 10 +#define OP_UNARY_NUM_FILL 11 +#define OP_UNARY_NUM_CLAMP 12 +#define OP_UNARY_NUM_SQR 13 +#define OP_UNARY_NUM_SQRT 14 +#define OP_UNARY_NUM_SIN 15 +#define OP_UNARY_NUM_COS 16 +#define OP_UNARY_NUM_LOG 17 +#define OP_UNARY_NUM_LEAKY_RELU 18 + +#define OP_UNARY_NUM_TANH 100 +#define OP_UNARY_NUM_RELU 101 +#define OP_UNARY_NUM_SIGMOID 102 +#define OP_UNARY_NUM_GELU 103 +#define OP_UNARY_NUM_GELU_ERF 104 +#define OP_UNARY_NUM_GELU_QUICK 105 +#define OP_UNARY_NUM_SILU 106 +#define OP_UNARY_NUM_ELU 107 +#define OP_UNARY_NUM_NEG 108 +#define OP_UNARY_NUM_ABS 109 +#define OP_UNARY_NUM_SGN 110 +#define OP_UNARY_NUM_STEP 111 +#define OP_UNARY_NUM_HARDSWISH 112 +#define OP_UNARY_NUM_HARDSIGMOID 113 +#define OP_UNARY_NUM_EXP 114 +#define OP_UNARY_NUM_SOFTPLUS 115 +#define OP_UNARY_NUM_EXPM1 116 + +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 + // kernel argument structs // // - element counters (e.g. ne00) typically use int32_t to reduce register usage @@ -123,6 +157,31 @@ typedef struct { int32_t dim; } ggml_metal_kargs_concat; +typedef struct { + int32_t ne00; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; + uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; + float slope; + float scale; + float bias; + float val; + float min; + float max; +} ggml_metal_kargs_unary; + typedef struct { int32_t ne00; int32_t ne01; @@ -180,20 +239,6 @@ typedef struct { uint64_t nb3; } ggml_metal_kargs_repeat; -typedef struct { - float scale; - float bias; -} ggml_metal_kargs_scale; - -typedef struct { - float val; -} ggml_metal_kargs_fill; - -typedef struct { - float min; - float max; -} ggml_metal_kargs_clamp; - typedef struct { int64_t nk0; int64_t ne00; @@ -497,8 +542,21 @@ typedef struct { typedef struct { int32_t ne00; - int32_t ne00_4; + int32_t ne01; + int32_t ne02; + int32_t ne03; + uint64_t nb00; uint64_t nb01; + uint64_t nb02; + uint64_t nb03; + int32_t ne0; + int32_t ne1; + int32_t ne2; + int32_t ne3; + uint64_t nb0; + uint64_t nb1; + uint64_t nb2; + uint64_t nb3; float eps; } ggml_metal_kargs_l2_norm; @@ -880,10 +938,6 @@ typedef struct { int max_period; } ggml_metal_kargs_timestep_embedding; -typedef struct { - float slope; -} ggml_metal_kargs_leaky_relu; - typedef struct { int32_t ne00; int32_t ne01; diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index e0ed6c7805..3d5db0b79f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { n_fuse = ggml_metal_op_acc(ctx, idx); } break; case GGML_OP_SCALE: - { - n_fuse = ggml_metal_op_scale(ctx, idx); - } break; case GGML_OP_FILL: - { - n_fuse = ggml_metal_op_fill(ctx, idx); - } break; case GGML_OP_CLAMP: - { - n_fuse = ggml_metal_op_clamp(ctx, idx); - } break; + case GGML_OP_LEAKY_RELU: case GGML_OP_SQR: case GGML_OP_SQRT: case GGML_OP_SIN: @@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_top_k(ctx, idx); } break; - case GGML_OP_LEAKY_RELU: - { - n_fuse = ggml_metal_op_leaky_relu(ctx, idx); - } break; case GGML_OP_TRI: { n_fuse = ggml_metal_op_tri(ctx, idx); @@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); } break; + case GGML_OP_SET: + { + n_fuse = ggml_metal_op_set(ctx, idx); + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -628,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); GGML_ASSERT(op->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); const size_t pnb1 = ((const int32_t *) op->op_params)[0]; const size_t pnb2 = ((const int32_t *) op->op_params)[1]; @@ -679,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, /*.nb00 =*/ nb00, /*.nb01 =*/ pnb1, /*.nb02 =*/ pnb2, @@ -695,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, /*.nb0 =*/ nb0, /*.nb1 =*/ pnb1, /*.nb2 =*/ pnb2, @@ -707,7 +699,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.o1 =*/ { 0 }, }; - auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false); + auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); @@ -715,126 +707,19 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); return 1; } -int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float scale; - float bias; - memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_scale args = { - /*.scale =*/ scale, - /*.bias =*/ bias, - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - const float val = ggml_get_op_params_f32(op, 0); - - ggml_metal_kargs_fill args = { - /*.val =*/ val - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - -int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float min; - float max; - memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float)); - memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float)); - - ggml_metal_kargs_clamp args = { - /*.min =*/ min, - /*.max =*/ max, - }; - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -846,19 +731,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - int64_t n = ggml_nelements(op); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); - if (n % 4 == 0) { - n /= 4; + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_kargs_unary args = { + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.slope =*/ 0.0, + /*.scale =*/ 0.0, + /*.bias =*/ 0.0, + /*.val =*/ 0.0, + /*.min =*/ 0.0, + /*.max =*/ 0.0, + }; + + if (op->op == GGML_OP_LEAKY_RELU) { + args.slope = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_SCALE) { + args.scale = ggml_get_op_params_f32(op, 0); + args.bias = ggml_get_op_params_f32(op, 1); + } + + if (op->op == GGML_OP_FILL) { + args.val = ggml_get_op_params_f32(op, 0); + } + + if (op->op == GGML_OP_CLAMP) { + args.min = ggml_get_op_params_f32(op, 0); + args.max = ggml_get_op_params_f32(op, 1); } auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); + } else { + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nth = MIN(args.ne00, nth_max); + + const int nk0 = (args.ne00 + nth - 1)/nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1); + } return 1; } @@ -969,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -990,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -1664,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type); + + GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0); + + int64_t nk0 = ne10; + if (ggml_is_quantized(op->src[1]->type)) { + nk0 = ne10/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne10/ggml_blck_size(op->type); + } + + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, + /*.nb00 =*/ nb10, + /*.nb01 =*/ nb11, + /*.nb02 =*/ nb12, + /*.nb03 =*/ nb13, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, + /*.nb0 =*/ ggml_element_size(op), + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + }; + + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + + bid_dst.offs += offs; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -2895,8 +2978,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); - bool bcast_row = false; - ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); @@ -2990,18 +3071,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { struct ggml_metal_pipeline_with_params pipeline; - if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) { - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - - // src1 is a row - GGML_ASSERT(ne11 == 1); - - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true); - - bcast_row = true; - } else { - pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false); - } + pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse); if (n_fuse > 1) { bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1)); @@ -3015,20 +3085,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) { } } + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne10 = ne10/4; + args.ne0 = ne0/4; + } + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, bid_src0, 1); ggml_metal_encoder_set_buffer (enc, bid_src1, 2); ggml_metal_encoder_set_buffer (enc, bid_dst, 3); - if (bcast_row) { - const int64_t n = ggml_nelements(op)/4; + if (pipeline.cnt) { + const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op); ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); } else { - int nth = 32; + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { nth *= 2; } @@ -3049,39 +3127,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + float eps; memcpy(&eps, op->op_params, sizeof(float)); - int nth = 32; // SIMD width - ggml_metal_kargs_l2_norm args = { - /*.ne00 =*/ ne00, - /*.ne00_4 =*/ ne00/4, - /*.nb01 =*/ nb01, - /*.eps =*/ eps, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + /*.eps =*/ eps, }; auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op); - while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + + int nth = 32; // SIMD width + + while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00/4); const size_t smem = pipeline.smem; - const int64_t nrows = ggml_nrows(op->src[0]); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); - ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); return 1; } @@ -4089,42 +4187,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) { return 1; } -int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { - ggml_tensor * op = ctx->node(idx); - - ggml_metal_library_t lib = ctx->lib; - ggml_metal_encoder_t enc = ctx->enc; - - GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); - GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); - GGML_TENSOR_LOCALS( int32_t, ne, op, ne); - GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); - - float slope; - memcpy(&slope, op->op_params, sizeof(float)); - - ggml_metal_kargs_leaky_relu args = { - /*.slope =*/ slope - }; - - auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op); - - int64_t n = ggml_nelements(op); - - if (n % 4 == 0) { - n /= 4; - } - - ggml_metal_encoder_set_pipeline(enc, pipeline); - ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - - ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1); - - return 1; -} - int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 3c64e4f600..f3e38c7aa9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx); int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); @@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); @@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx); -int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 612a42a1ea..6c349aa0c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -895,60 +903,217 @@ enum ggml_sort_order { GGML_SORT_ORDER_DESC, }; -// general-purpose kernel for addition, subtraction, multiplication and division of two tensors -// pros: works for non-contiguous tensors, supports broadcast across all dims -// cons: not very efficient -template -kernel void kernel_add_fuse_impl( - constant ggml_metal_kargs_bin & args, +constant float GELU_COEF_A = 0.044715f; +constant float GELU_QUICK_COEF = -1.702f; +constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; +constant float SQRT_2_INV = 0.70710678118654752440084436210484f; + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +constant float p_erf = 0.3275911f; +constant float a1_erf = 0.254829592f; +constant float a2_erf = -0.284496736f; +constant float a3_erf = 1.421413741f; +constant float a4_erf = -1.453152027f; +constant float a5_erf = 1.061405429f; + +template +inline T erf_approx(T x) { + T sign_x = sign(x); + x = fabs(x); + T t = 1.0f / (1.0f + p_erf * x); + T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + return sign_x * y; +} + +template T elu_approx(T x); + +template<> inline float elu_approx(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} + +template<> inline float4 elu_approx(float4 x) { + float4 res; + + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); + + return res; +} + +constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; +constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; + +template +kernel void kernel_unary_impl( + constant ggml_metal_kargs_unary & args, device const char * src0, - device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_unary_op +#define FC_CNT FC_unary_cnt - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + device const T0 * src0_ptr; + device T * dst_ptr; - device const float * src0_ptr = (device const float *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); - device float * dst_ptr = (device float *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + int i0; - device const float * src1_ptr[F]; - for (short j = 0; j < F; ++j) { - src1_ptr[j] = (device const float *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + if (FC_CNT) { + i0 = tgpig.x; + + src0_ptr = (device const T0 *) (src0); + dst_ptr = (device T *) (dst); + } else { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int k0 = tgpig.x/args.ne01; + const int i01 = tgpig.x - k0*args.ne01; + + i0 = k0*ntg.x + tpitg.x; + + src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 ); } - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; + { + //threadgroup_barrier(mem_flags::mem_none); - float res = src0_ptr[i0]; - -#pragma unroll - for (short j = 0; j < F; ++j) { - res += src1_ptr[j][i10]; + if (!FC_CNT) { + if (i0 >= args.ne0) { + return; + } } - dst_ptr[i0] = res; + const TC x = (TC) src0_ptr[i0]; + + if (FC_OP == OP_UNARY_NUM_SCALE) { + dst_ptr[i0] = (T) (args.scale * x + args.bias); + } + + if (FC_OP == OP_UNARY_NUM_FILL) { + dst_ptr[i0] = (T) args.val; + } + + if (FC_OP == OP_UNARY_NUM_CLAMP) { + dst_ptr[i0] = (T) clamp(x, args.min, args.max); + } + + if (FC_OP == OP_UNARY_NUM_SQR) { + dst_ptr[i0] = (T) (x * x); + } + + if (FC_OP == OP_UNARY_NUM_SQRT) { + dst_ptr[i0] = (T) sqrt(x); + } + + if (FC_OP == OP_UNARY_NUM_SIN) { + dst_ptr[i0] = (T) sin(x); + } + + if (FC_OP == OP_UNARY_NUM_COS) { + dst_ptr[i0] = (T) cos(x); + } + + if (FC_OP == OP_UNARY_NUM_LOG) { + dst_ptr[i0] = (T) log(x); + } + + if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); + } + + if (FC_OP == OP_UNARY_NUM_TANH) { + dst_ptr[i0] = (T) precise::tanh(x); + } + + if (FC_OP == OP_UNARY_NUM_RELU) { + dst_ptr[i0] = (T) fmax(0, x); + } + + if (FC_OP == OP_UNARY_NUM_SIGMOID) { + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); + } + + if (FC_OP == OP_UNARY_NUM_GELU) { + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); + } + + if (FC_OP == OP_UNARY_NUM_GELU_ERF) { + dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); + } + + if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); + } + + if (FC_OP == OP_UNARY_NUM_SILU) { + dst_ptr[i0] = (T) (x / (1 + exp(-x))); + } + + if (FC_OP == OP_UNARY_NUM_ELU) { + dst_ptr[i0] = (T) elu_approx(x); + } + + if (FC_OP == OP_UNARY_NUM_NEG) { + dst_ptr[i0] = (T) -x; + } + + if (FC_OP == OP_UNARY_NUM_ABS) { + dst_ptr[i0] = (T) fabs(x); + } + + if (FC_OP == OP_UNARY_NUM_SGN) { + dst_ptr[i0] = T(x > 0) - T(x < 0); + } + + if (FC_OP == OP_UNARY_NUM_STEP) { + dst_ptr[i0] = T(x > 0); + } + + if (FC_OP == OP_UNARY_NUM_HARDSWISH) { + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); + } + + if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); + } + + if (FC_OP == OP_UNARY_NUM_EXP) { + dst_ptr[i0] = (T) exp(x); + } + + if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); + } + + if (FC_OP == OP_UNARY_NUM_EXPM1) { + // TODO: precise implementation + dst_ptr[i0] = (T) (exp(x) - 1); + } } + +#undef FC_OP +#undef FC_CNT } -typedef decltype(kernel_add_fuse_impl<2>) kernel_add_fuse_t; +typedef decltype(kernel_unary_impl) kernel_unary_t; -template [[host_name("kernel_add_fuse_1")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<1>; -template [[host_name("kernel_add_fuse_2")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<2>; -template [[host_name("kernel_add_fuse_3")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<3>; -template [[host_name("kernel_add_fuse_4")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<4>; -template [[host_name("kernel_add_fuse_5")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<5>; -template [[host_name("kernel_add_fuse_6")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<6>; -template [[host_name("kernel_add_fuse_7")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<7>; -template [[host_name("kernel_add_fuse_8")]] kernel kernel_add_fuse_t kernel_add_fuse_impl<8>; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl; -kernel void kernel_sub_fuse_1( +// OP: 0 - add, 1 - sub, 2 - mul, 3 - div +constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; +constant short FC_bin_f [[function_constant(FC_BIN + 1)]]; +constant bool FC_bin_rb [[function_constant(FC_BIN + 2)]]; + +template +kernel void kernel_bin_fuse_impl( constant ggml_metal_kargs_bin & args, device const char * src0, device const char * src1, @@ -956,89 +1121,152 @@ kernel void kernel_sub_fuse_1( uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +#define FC_OP FC_bin_op +#define FC_F FC_bin_f +#define FC_RB FC_bin_rb - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_RB) { + // row broadcast + const uint i0 = tgpig.x; + const uint i1 = i0%args.ne10; - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + device const T0 * src0_row = (device const T0 *) (src0); + device T * dst_row = (device T *) (dst); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) - *((device float *)(src1_ptr + i10*args.nb10)); - } -} + if (FC_F == 1) { + device const T1 * src1_row = (device const T1 *) (src1 + args.o1[0]); -kernel void kernel_mul_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; + if (FC_OP == 0) { + dst_row[i0] = src0_row[i0] + src1_row[i1]; + } - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; + if (FC_OP == 1) { + dst_row[i0] = src0_row[i0] - src1_row[i1]; + } - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; + if (FC_OP == 2) { + dst_row[i0] = src0_row[i0] * src1_row[i1]; + } - if (args.ne10 == 1) { - const float x = *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; + if (FC_OP == 3) { + dst_row[i0] = src0_row[i0] / src1_row[i1]; + } + } else { + T0 res = src0_row[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= ((device const T1 *) (src1 + args.o1[j]))[i1]; + } + } + + dst_row[i0] = res; } } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * *((device float *)(src1_ptr + i10*args.nb10)); + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + + if (i01 >= args.ne01) { + return; + } + + const int i13 = i03%args.ne13; + const int i12 = i02%args.ne12; + const int i11 = i01%args.ne11; + + device const T0 * src0_ptr = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs); + device T * dst_ptr = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs); + + if (FC_F == 1) { + device const T1 * src1_ptr = (device const T1 *) (src1 + args.o1[0] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + if (FC_OP == 0) { + dst_ptr[i0] = src0_ptr[i0] + src1_ptr[i10]; + } + + if (FC_OP == 1) { + dst_ptr[i0] = src0_ptr[i0] - src1_ptr[i10]; + } + + if (FC_OP == 2) { + dst_ptr[i0] = src0_ptr[i0] * src1_ptr[i10]; + } + + if (FC_OP == 3) { + dst_ptr[i0] = src0_ptr[i0] / src1_ptr[i10]; + } + } + } else { + device const T1 * src1_ptr[8]; + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + src1_ptr[j] = (device const T1 *) (src1 + args.o1[j] + i13*args.nb13 + i12*args.nb12 + i11*args.nb11); + } + + for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { + const int i10 = i0%args.ne10; + + T res = src0_ptr[i0]; + + if (FC_OP == 0) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res += src1_ptr[j][i10]; + } + } + + if (FC_OP == 1) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res -= src1_ptr[j][i10]; + } + } + + if (FC_OP == 2) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res *= src1_ptr[j][i10]; + } + } + + if (FC_OP == 3) { + FOR_UNROLL (short j = 0; j < FC_F; ++j) { + res /= src1_ptr[j][i10]; + } + } + + dst_ptr[i0] = res; + } } } + +#undef FC_OP +#undef FC_F +#undef FC_RB } -kernel void kernel_div_fuse_1( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig.z; - const int i02 = tgpig.y; - const int i01 = tgpig.x; +typedef decltype(kernel_bin_fuse_impl) kernel_bin_fuse_t; - const int i13 = i03%args.ne13; - const int i12 = i02%args.ne12; - const int i11 = i01%args.ne11; - - device const char * src0_ptr = src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + args.offs; - device const char * src1_ptr = src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + args.o1[0]; - device char * dst_ptr = dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1 + args.offs; - - if (args.ne10 == 1) { - const float x = 1.0f / *((device float *)(src1_ptr)); - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) * x; - } - } else { - for (int i0 = tpitg.x; i0 < args.ne0; i0 += ntg.x) { - const int i10 = i0%args.ne10; - *((device float *)(dst_ptr + i0*args.nb0)) = *((device float *)(src0_ptr + i0*args.nb00)) / *((device float *)(src1_ptr + i10*args.nb10)); - } - } -} +template [[host_name("kernel_bin_fuse_f32_f32_f32")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; +template [[host_name("kernel_bin_fuse_f32_f32_f32_4")]] kernel kernel_bin_fuse_t kernel_bin_fuse_impl; kernel void kernel_add_id( constant ggml_metal_kargs_add_id & args, @@ -1057,7 +1285,7 @@ kernel void kernel_add_id( const size_t nb1 = args.ne0 * sizeof(float); const size_t nb2 = args.ne1 * nb1; - device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); + device float * dst_row = (device float *)((device char *)dst + i1*nb1 + i2*nb2); device const float * src0_row = (device const float *)((device char *)src0 + i1*args.nb01 + i2*args.nb02); device const float * src1_row = (device const float *)((device char *)src1 + i11*args.nb11); @@ -1098,549 +1326,6 @@ template [[host_name("kernel_repeat_f16")]] kernel kernel_repeat_t kernel_repeat template [[host_name("kernel_repeat_i32")]] kernel kernel_repeat_t kernel_repeat; template [[host_name("kernel_repeat_i16")]] kernel kernel_repeat_t kernel_repeat; -// assumption: src1 is a row -// broadcast src1 into src0 -template -kernel void kernel_add_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res += ((device const float4 *) (src1 + args.o1[j]))[i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_add_row_c4_fuse_impl<1>) kernel_add_row_c4_fuse_t; - -template [[host_name("kernel_add_row_c4_fuse_1")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<1>; -template [[host_name("kernel_add_row_c4_fuse_2")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<2>; -template [[host_name("kernel_add_row_c4_fuse_3")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<3>; -template [[host_name("kernel_add_row_c4_fuse_4")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<4>; -template [[host_name("kernel_add_row_c4_fuse_5")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<5>; -template [[host_name("kernel_add_row_c4_fuse_6")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<6>; -template [[host_name("kernel_add_row_c4_fuse_7")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<7>; -template [[host_name("kernel_add_row_c4_fuse_8")]] kernel kernel_add_row_c4_fuse_t kernel_add_row_c4_fuse_impl<8>; - -template -kernel void kernel_sub_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res -= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_sub_row_c4_fuse_impl<1>) kernel_sub_row_c4_fuse_t; - -template [[host_name("kernel_sub_row_c4_fuse_1")]] kernel kernel_sub_row_c4_fuse_t kernel_sub_row_c4_fuse_impl<1>; - -template -kernel void kernel_mul_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res *= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_mul_row_c4_fuse_impl<1>) kernel_mul_row_c4_fuse_t; - -template [[host_name("kernel_mul_row_c4_fuse_1")]] kernel kernel_mul_row_c4_fuse_t kernel_mul_row_c4_fuse_impl<1>; - -template -kernel void kernel_div_row_c4_fuse_impl( - constant ggml_metal_kargs_bin & args, - device const char * src0, - device const char * src1, - device char * dst, - uint tpig[[thread_position_in_grid]]) { - - const uint nb = args.ne00/4; - const uint i = tpig % nb; - - device const float4 * src0_row = (device const float4 *) (src0); - device float4 * dst_row = (device float4 *) (dst); - - device const float4 * src1_row[F]; - for (short j = 0; j < F; ++j) { - src1_row[j] = (device const float4 *) (src1 + args.o1[j]); - } - - float4 res = src0_row[tpig]; - -#pragma unroll(F) - for (short j = 0; j < F; ++j) { - res /= src1_row[j][i]; - } - - dst_row[tpig] = res; -} - -typedef decltype(kernel_div_row_c4_fuse_impl<1>) kernel_div_row_c4_fuse_t; - -template [[host_name("kernel_div_row_c4_fuse_1")]] kernel kernel_div_row_c4_fuse_t kernel_div_row_c4_fuse_impl<1>; - -kernel void kernel_scale_f32( - constant ggml_metal_kargs_scale & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_scale_f32_4( - constant ggml_metal_kargs_scale & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * args.scale + args.bias; -} - -kernel void kernel_fill_f32( - constant ggml_metal_kargs_fill & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_fill_f32_4( - constant ggml_metal_kargs_fill & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = args.val; -} - -kernel void kernel_clamp_f32( - constant ggml_metal_kargs_clamp & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_clamp_f32_4( - constant ggml_metal_kargs_clamp & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = clamp(src0[tpig], args.min, args.max); -} - -kernel void kernel_relu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_relu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = max(0.0f, src0[tpig]); -} - -kernel void kernel_sigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_sigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = 1.0f / (1.0f + exp(-src0[tpig])); -} - -kernel void kernel_tanh_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -kernel void kernel_tanh_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = precise::tanh(src0[tpig]); -} - -constant float GELU_COEF_A = 0.044715f; -constant float GELU_QUICK_COEF = -1.702f; -constant float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -constant float SQRT_2_INV = 0.70710678118654752440084436210484f; - -kernel void kernel_gelu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - // BEWARE !!! - // Simply using "tanh" instead of "precise::tanh" will sometimes results in NaNs! - // This was observed with Falcon 7B and 40B models - // - dst[tpig] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); -} - -kernel void kernel_gelu_quick_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -kernel void kernel_gelu_quick_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = x*(1.0f/(1.0f+exp(GELU_QUICK_COEF*x))); -} - -// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation -// ref: https://www.johndcook.com/blog/python_erf/ -constant float p_erf = 0.3275911f; -constant float a1_erf = 0.254829592f; -constant float a2_erf = -0.284496736f; -constant float a3_erf = 1.421413741f; -constant float a4_erf = -1.453152027f; -constant float a5_erf = 1.061405429f; - -template -T erf_approx(T x) { - T sign_x = sign(x); - x = fabs(x); - T t = 1.0f / (1.0f + p_erf * x); - T y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); - return sign_x * y; -} - -kernel void kernel_gelu_erf_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_gelu_erf_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - - dst[tpig] = 0.5f*x*(1.0f+erf_approx(x*SQRT_2_INV)); -} - -kernel void kernel_silu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_silu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = x / (1.0f + exp(-x)); -} - -kernel void kernel_elu_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = (x > 0.0f) ? x : (exp(x) - 1.0f); -} - -kernel void kernel_elu_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig][0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); - dst[tpig][1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); - dst[tpig][2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); - dst[tpig][3] = (x[3] > 0.0f) ? x[3] : (exp(x[3]) - 1.0f); -} - -kernel void kernel_sqr_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqr_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = src0[tpig] * src0[tpig]; -} - -kernel void kernel_sqrt_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sqrt_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sqrt(src0[tpig]); -} - -kernel void kernel_sin_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_sin_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sin(src0[tpig]); -} - -kernel void kernel_cos_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_cos_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = cos(src0[tpig]); -} - -kernel void kernel_log_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_log_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = log(src0[tpig]); -} - -kernel void kernel_neg_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_neg_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = -src0[tpig]; -} - -kernel void kernel_abs_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_abs_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = fabs(src0[tpig]); -} - -kernel void kernel_sgn_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_sgn_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = sign(src0[tpig]); -} - -kernel void kernel_step_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_step_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = step(0.0f, src0[tpig]); -} - -kernel void kernel_hardswish_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardswish_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = x * fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_hardsigmoid_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = fmin(1.0f, fmax(0.0f, (x + 3.0f) / 6.0f)); -} - -kernel void kernel_exp_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_exp_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]); -} - -kernel void kernel_softplus_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - device const float & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_softplus_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - device const float4 & x = src0[tpig]; - dst[tpig] = select(log(1.0f + exp(x)), x, x > 20.0f); -} - -kernel void kernel_expm1_f32( - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - -kernel void kernel_expm1_f32_4( - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - dst[tpig] = exp(src0[tpig]) - 1.0f; -} - kernel void kernel_reglu_f32( constant ggml_metal_kargs_glu & args, device const char * src0, @@ -1824,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1861,23 +1548,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; template kernel void kernel_cumsum_blk( @@ -2758,9 +2455,6 @@ kernel void kernel_solve_tri_f32( const short K = FC_solve_tri_k; const short NP = PAD2(N, NW); - const int32_t ne02 = args.ne02; - const int32_t ne03 = args.ne03; - const int32_t i03 = tgpig.z; const int32_t i02 = tgpig.y; const int32_t i01 = tgpig.x*NSG + sgitg; @@ -3047,26 +2741,32 @@ template [[host_name("kernel_rms_norm_f32_4")]] kernel kernel_rms_norm_f template [[host_name("kernel_rms_norm_mul_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; template [[host_name("kernel_rms_norm_mul_add_f32_4")]] kernel kernel_rms_norm_fuse_t kernel_rms_norm_fuse_impl; -kernel void kernel_l2_norm_f32( +template +kernel void kernel_l2_norm_impl( constant ggml_metal_kargs_l2_norm & args, device const char * src0, device char * dst, threadgroup float * shmem_f32 [[threadgroup(0)]], - uint tgpig[[threadgroup_position_in_grid]], - ushort tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort ntg[[threads_per_threadgroup]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int i03 = tgpig.z; + const int i02 = tgpig.y; + const int i01 = tgpig.x; + if (sgitg == 0) { shmem_f32[tiisg] = 0.0f; } - device const float4 * x = (device const float4 *) (src0 + tgpig*args.nb01); + device const T0 * x = (device const T0 *) (src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); + device T * y = (device T *) (dst + i03*args.nb3 + i02*args.nb2 + i01*args.nb1); float sumf = 0.0f; // parallel sum - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { sumf += dot(x[i00], x[i00]); } sumf = simd_sum(sumf); @@ -3084,12 +2784,16 @@ kernel void kernel_l2_norm_f32( const float scale = 1.0f/sqrt(max(sumf, args.eps)); - device float4 * y = (device float4 *) dst + tgpig*args.ne00_4; - for (int i00 = tpitg; i00 < args.ne00_4; i00 += ntg) { + for (int i00 = tpitg.x; i00 < args.ne00; i00 += ntg.x) { y[i00] = x[i00] * scale; } } +typedef decltype(kernel_l2_norm_impl) kernel_l2_norm_t; + +template [[host_name("kernel_l2_norm_f32_f32")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; +template [[host_name("kernel_l2_norm_f32_f32_4")]] kernel kernel_l2_norm_t kernel_l2_norm_impl; + kernel void kernel_group_norm_f32( constant ggml_metal_kargs_group_norm & args, device const float * src0, @@ -5191,24 +4895,6 @@ kernel void kernel_argsort_merge_f32_i32( template [[host_name("kernel_argsort_merge_f32_i32_asc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; template [[host_name("kernel_argsort_merge_f32_i32_desc")]] kernel argsort_merge_t kernel_argsort_merge_f32_i32; -kernel void kernel_leaky_relu_f32( - constant ggml_metal_kargs_leaky_relu & args, - device const float * src0, - device float * dst, - uint tpig[[thread_position_in_grid]]) { - const float x = src0[tpig]; - dst[tpig] = x > 0.0f ? x : x * args.slope; -} - -kernel void kernel_leaky_relu_f32_4( - constant ggml_metal_kargs_leaky_relu & args, - device const float4 * src0, - device float4 * dst, - uint tpig[[thread_position_in_grid]]) { - const float4 x = src0[tpig]; - dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); -} - constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; @@ -6280,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t @@ -8868,7 +8554,9 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8991,8 +8679,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -9241,7 +8929,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9376,8 +9066,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -10058,7 +9748,7 @@ kernel void kernel_opt_step_sgd_f32( template kernel void kernel_memset( - constant ggml_metal_kargs_fill & args, + constant ggml_metal_kargs_memset & args, device T * dst, uint tpig[[thread_position_in_grid]]) { dst[tpig] = args.val; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112..f389193691 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q4_1_f32 + mul_mv_q4_1_f32_flat + mul_mv_q4_k_f32 mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS gemv_moe_mxfp4_f32 mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q4_0_f32_l4_lm + mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_general_q8_0_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f03..3da022ed86 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -484,7 +484,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_scale_f32, kernel_scale_f32_4; cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4; cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4; - cl_kernel kernel_mean_f32; + cl_kernel kernel_mean_f32, kernel_mean_f32_4; cl_kernel kernel_silu, kernel_silu_4; cl_kernel kernel_gelu, kernel_gelu_4; cl_kernel kernel_gelu_erf, kernel_gelu_erf_4; @@ -525,6 +525,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -532,6 +533,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_1_f32; + cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -539,15 +543,15 @@ struct ggml_backend_opencl_context { cl_kernel kernel_solve_tri_f32; cl_kernel kernel_im2col_f32, kernel_im2col_f16; cl_kernel kernel_argsort_f32_i32; - cl_kernel kernel_sum_rows_f32; + cl_kernel kernel_sum_rows_f32, kernel_sum_rows_f32_4; cl_kernel kernel_repeat_f32; cl_kernel kernel_pad; cl_kernel kernel_tanh_f32, kernel_tanh_f32_4, kernel_tanh_f32_nc; cl_kernel kernel_tanh_f16, kernel_tanh_f16_4, kernel_tanh_f16_nc; - cl_kernel kernel_expm1_f32_nd; - cl_kernel kernel_expm1_f16_nd; - cl_kernel kernel_softplus_f32_nd; - cl_kernel kernel_softplus_f16_nd; + cl_kernel kernel_expm1_f32, kernel_expm1_f32_4, kernel_expm1_f32_nc; + cl_kernel kernel_expm1_f16, kernel_expm1_f16_4, kernel_expm1_f16_nc; + cl_kernel kernel_softplus_f32, kernel_softplus_f32_4, kernel_softplus_f32_nc; + cl_kernel kernel_softplus_f16, kernel_softplus_f16_4, kernel_softplus_f16_nc; cl_kernel kernel_upscale; cl_kernel kernel_upscale_bilinear; cl_kernel kernel_concat_f32; @@ -563,7 +567,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -886,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); @@ -1117,6 +1126,57 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1342,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q4_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1358,6 +1450,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q6_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q6_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1728,6 +1837,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_mean_f32_4 = clCreateKernel(prog, "kernel_mean_f32_4", &err), err)); CL_CHECK(clReleaseProgram(prog)); GGML_LOG_CONT("."); @@ -1765,6 +1875,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); CL_CHECK((backend_ctx->kernel_sum_rows_f32 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_sum_rows_f32_4 = clCreateKernel(backend_ctx->program_sum_rows_f32, "kernel_sum_rows_f32_4", &err), err)); GGML_LOG_CONT("."); } @@ -1869,20 +1980,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("expm1.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_expm1_f32_nd = clCreateKernel(prog, "kernel_expm1_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_expm1_f16_nd = clCreateKernel(prog, "kernel_expm1_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: expm1 kernel source not found or empty. Expm1 operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_expm1_f32_nd = nullptr; - backend_ctx->kernel_expm1_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_expm1_f32 = clCreateKernel(prog, "kernel_expm1_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_4 = clCreateKernel(prog, "kernel_expm1_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f32_nc = clCreateKernel(prog, "kernel_expm1_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16 = clCreateKernel(prog, "kernel_expm1_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_4 = clCreateKernel(prog, "kernel_expm1_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_expm1_f16_nc = clCreateKernel(prog, "kernel_expm1_f16_nc", &err), err)); CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // softplus @@ -1894,20 +2001,16 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve #else const std::string kernel_src = read_file("softplus.cl"); #endif - cl_program prog; - if (!kernel_src.empty()) { - prog = - build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); - CL_CHECK((backend_ctx->kernel_softplus_f32_nd = clCreateKernel(prog, "kernel_softplus_f32_nd", &err), err)); - CL_CHECK((backend_ctx->kernel_softplus_f16_nd = clCreateKernel(prog, "kernel_softplus_f16_nd", &err), err)); - GGML_LOG_CONT("."); - } else { - GGML_LOG_WARN("ggml_opencl: softplus kernel source not found or empty. Softplus operation will not be available.\n"); - prog = nullptr; - backend_ctx->kernel_softplus_f32_nd = nullptr; - backend_ctx->kernel_softplus_f16_nd = nullptr; - } + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + CL_CHECK((backend_ctx->kernel_softplus_f32 = clCreateKernel(prog, "kernel_softplus_f32", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_4 = clCreateKernel(prog, "kernel_softplus_f32_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f32_nc = clCreateKernel(prog, "kernel_softplus_f32_nc", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16 = clCreateKernel(prog, "kernel_softplus_f16", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_4 = clCreateKernel(prog, "kernel_softplus_f16_4", &err), err)); + CL_CHECK((backend_ctx->kernel_softplus_f16_nc = clCreateKernel(prog, "kernel_softplus_f16_nc", &err), err)); CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); } // upscale @@ -2887,6 +2990,59 @@ struct ggml_tensor_extra_cl_q4_0 { } }; +struct ggml_tensor_extra_cl_q4_1 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q4_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. + // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_q = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -3301,11 +3457,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_UNARY_OP_TANH: return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_EXPM1: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; case GGML_UNARY_OP_SOFTPLUS: - return (op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) || - (op->src[0]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16); + return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16; default: return false; } @@ -3363,7 +3517,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -3423,7 +3579,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te } case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: - return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); + return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_FLASH_ATTN_EXT: { const ggml_tensor * q = op->src[0]; @@ -3592,6 +3748,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() { + ggml_tensor_extra_cl_q4_1 * extra; + if (temp_tensor_extras_q4_1.empty()) { + extra = new ggml_tensor_extra_cl_q4_1(); + } else { + extra = temp_tensor_extras_q4_1.back(); + temp_tensor_extras_q4_1.pop_back(); + } + + temp_tensor_extras_q4_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -3648,6 +3819,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + temp_tensor_extras_q4_1.push_back(e); + } + temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -3673,6 +3849,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_in_use; std::vector temp_tensor_extras_q4_0; std::vector temp_tensor_extras_q4_0_in_use; + std::vector temp_tensor_extras_q4_1; + std::vector temp_tensor_extras_q4_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -4042,6 +4220,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. + ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -4544,7 +4791,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size, data, 0, NULL, NULL)); CL_CHECK(clReleaseMemObject(data_device)); return; - } else if (tensor->type == GGML_TYPE_MXFP4) { + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -6117,7 +6392,6 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const GGML_UNUSED(src1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -6140,7 +6414,14 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - cl_kernel kernel = backend_ctx->kernel_mean_f32; + cl_kernel kernel; + + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_mean_f32_4; + } else { + kernel = backend_ctx->kernel_mean_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -6157,7 +6438,7 @@ static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); @@ -7105,18 +7386,8 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_expm1_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_expm1_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_expm1"); - } - GGML_ASSERT(kernel != nullptr); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -7128,70 +7399,74 @@ static void ggml_cl_expm1(ggml_backend_t backend, const ggml_tensor * src0, cons const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); - - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); - - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); - - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; - - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; - - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; - - - size_t local_work_size[] = {lws0, lws1, lws2}; - - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_4; + } else { + kernel = backend_ctx->kernel_expm1_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32; + } else { + kernel = backend_ctx->kernel_expm1_f16; + } } - } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_expm1_f32_nc; + } else { + kernel = backend_ctx->kernel_expm1_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) { @@ -7207,18 +7482,8 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra; ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra; - cl_ulong offset0_abs = extra0->offset + src0->view_offs; - cl_ulong offsetd_abs = extrad->offset + dst->view_offs; - - cl_kernel kernel; - if (dst->type == GGML_TYPE_F32) { - kernel = backend_ctx->kernel_softplus_f32_nd; - } else if (dst->type == GGML_TYPE_F16) { - kernel = backend_ctx->kernel_softplus_f16_nd; - } else { - GGML_ASSERT(false && "Unsupported type for ggml_cl_softplus"); - } - GGML_ASSERT(kernel != nullptr); + cl_ulong offset0 = extra0->offset + src0->view_offs; + cl_ulong offsetd = extrad->offset + dst->view_offs; const int ne00 = src0->ne[0]; const int ne01 = src0->ne[1]; @@ -7230,70 +7495,74 @@ static void ggml_cl_softplus(ggml_backend_t backend, const ggml_tensor * src0, c const cl_ulong nb02 = src0->nb[2]; const cl_ulong nb03 = src0->nb[3]; - const int ne10 = dst->ne[0]; - const int ne11 = dst->ne[1]; - const int ne12 = dst->ne[2]; - const int ne13 = dst->ne[3]; + const cl_ulong nb0 = dst->nb[0]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; - const cl_ulong nb10 = dst->nb[0]; - const cl_ulong nb11 = dst->nb[1]; - const cl_ulong nb12 = dst->nb[2]; - const cl_ulong nb13 = dst->nb[3]; + cl_kernel kernel; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0_abs)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd_abs)); - - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb00)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb01)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong),&nb02)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong),&nb03)); - - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne11)); - CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne12)); - CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne13)); - CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong),&nb10)); - CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong),&nb11)); - CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong),&nb12)); - CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong),&nb13)); - - size_t global_work_size[3]; - if (ne10 == 0 || ne11 == 0 || ne12 == 0 || ne13 == 0) { // Handle case of 0 elements - return; - } - global_work_size[0] = (size_t)ne10; - global_work_size[1] = (size_t)ne11; - global_work_size[2] = (size_t)ne12; - - size_t lws0 = 16, lws1 = 4, lws2 = 1; - if (ne10 < 16) lws0 = ne10; - if (ne11 < 4) lws1 = ne11; - if (ne12 < 1) lws2 = ne12 > 0 ? ne12 : 1; - - while (lws0 * lws1 * lws2 > 256 && lws0 > 1) lws0 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws1 > 1) lws1 /= 2; - while (lws0 * lws1 * lws2 > 256 && lws2 > 1) lws2 /= 2; - - - size_t local_work_size[] = {lws0, lws1, lws2}; - - size_t* local_work_size_ptr = local_work_size; - if (!backend_ctx->non_uniform_workgroups) { - if (global_work_size[0] % local_work_size[0] != 0 || - global_work_size[1] % local_work_size[1] != 0 || - global_work_size[2] % local_work_size[2] != 0) { - local_work_size_ptr = NULL; + if (ggml_is_contiguous(src0)) { + // Handle contiguous input + int n = ggml_nelements(dst); + if (n % 4 == 0) { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_4; + } else { + kernel = backend_ctx->kernel_softplus_f16_4; + } + n /= 4; + } else { + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32; + } else { + kernel = backend_ctx->kernel_softplus_f16; + } } - } - if (global_work_size[0] == 0 || global_work_size[1] == 0 || global_work_size[2] == 0) return; - backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + + size_t global_work_size[] = {(size_t)n, 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + size_t * local_work_size_ptr = local_work_size; + if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) { + local_work_size_ptr = nullptr; + } + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst); + } else { + // Handle non-contiguous input + if (src0->type == GGML_TYPE_F32) { + kernel = backend_ctx->kernel_softplus_f32_nc; + } else { + kernel = backend_ctx->kernel_softplus_f16_nc; + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &nb00)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb0)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb3)); + + int nth = 64; + + size_t global_work_size[] = {(size_t)ne01*nth, (size_t)ne02, (size_t)ne03}; + size_t local_work_size[] = {(size_t)nth, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + } } static void ggml_cl_repeat(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1_shape_def, ggml_tensor * dst) { @@ -8372,6 +8641,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -8885,6 +9155,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -8927,6 +9282,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q6_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. + size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } @@ -9181,7 +9580,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -9262,7 +9725,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K: { + kernel = backend_ctx->kernel_mul_mv_q4_K_f32; + + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + break; + } case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q @@ -9424,7 +9922,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q4_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { @@ -10573,7 +11074,6 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_UNUSED(src1); GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type)); - GGML_ASSERT(ggml_is_contiguous(src0)); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -10596,7 +11096,14 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c const cl_ulong nb2 = dst->nb[2]; const cl_ulong nb3 = dst->nb[3]; - cl_kernel kernel = backend_ctx->kernel_sum_rows_f32; + cl_kernel kernel; + + const bool is_c4 = ne00 % 4 == 0; + if (is_c4) { + kernel = backend_ctx->kernel_sum_rows_f32_4; + } else { + kernel = backend_ctx->kernel_sum_rows_f32; + } CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); @@ -10613,7 +11120,7 @@ static void ggml_cl_sum_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2)); CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03}; + size_t global_work_size[] = {64 * (size_t)ne01, (size_t)ne02, (size_t)ne03}; size_t local_work_size[] = {(size_t)64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 9fb434713d..2c244ce321 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,15 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_1 +// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. +//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_1( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + + for (int i = 0; i < QK4_1/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_1( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/2; ++i) { + b->qs[i] = q[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/expm1.cl b/ggml/src/ggml-opencl/kernels/expm1.cl index 126298a2cd..05442ac204 100644 --- a/ggml/src/ggml-opencl/kernels/expm1.cl +++ b/ggml/src/ggml-opencl/kernels/expm1.cl @@ -3,80 +3,111 @@ //------------------------------------------------------------------------------ // expm1 //------------------------------------------------------------------------------ -kernel void kernel_expm1_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, + +kernel void kernel_expm1_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0f; +} + +kernel void kernel_expm1_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = exp(src0[get_global_id(0)]) - 1.0h; +} + +kernel void kernel_expm1_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } -kernel void kernel_expm1_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_expm1_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * x = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * y = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = exp(*src_val_ptr) - 1; - } + *y = exp(*x) - 1.0f; } } diff --git a/ggml/src/ggml-opencl/kernels/mean.cl b/ggml/src/ggml-opencl/kernels/mean.cl index 5c3e8bcd86..7c7e0a587e 100644 --- a/ggml/src/ggml-opencl/kernels/mean.cl +++ b/ggml/src/ggml-opencl/kernels/mean.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_mean_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_mean_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; } - dst_row[0] = row_sum / ne00; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; + } + + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } +} + +kernel void kernel_mean_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf / ne00; + } } diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl new file mode 100644 index 0000000000..4100e3080a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_0_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl new file mode 100644 index 0000000000..d0d2f08361 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_1_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 0000000000..3602c92fef --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl new file mode 100644 index 0000000000..6fe828f20e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl @@ -0,0 +1,219 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y( + global const struct block_q4_1 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl new file mode 100644 index 0000000000..d7c4645d67 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl @@ -0,0 +1,229 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y_flat( + global const uchar * x, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *) x + il/2); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + // The number of scales/mins is the same as the number of blocks. + ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)); + // Each block contains QK4_1/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_dm; + global half * m = (global half *) src0_m + offset0_dm; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 0000000000..71ab989821 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh = &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/softplus.cl b/ggml/src/ggml-opencl/kernels/softplus.cl index 033766e2e0..6f8b747416 100644 --- a/ggml/src/ggml-opencl/kernels/softplus.cl +++ b/ggml/src/ggml-opencl/kernels/softplus.cl @@ -3,86 +3,114 @@ //------------------------------------------------------------------------------ // softplus //------------------------------------------------------------------------------ -inline float softplus_f32(float x){ - float ax = fabs(x); - float m = fmax(x, 0.0f); - return log1p(exp(-ax)) + m; + +kernel void kernel_softplus_f32( + global const float * src0, + ulong offset0, + global float * dst, + ulong offsetd +) { + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); } -kernel void kernel_softplus_f32_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f32_4( + global const float4 * src0, + ulong offset0, + global float4 * dst, + ulong offsetd +) { + src0 = (global float4*)((global char*)src0 + offset0); + dst = (global float4*)((global char*)dst + offsetd); + + dst[get_global_id(0)] = (src0[get_global_id(0)] > 20.0f) ? src0[get_global_id(0)] : log(1.0f + exp(src0[get_global_id(0)])); +} + +kernel void kernel_softplus_f16( + global const half * src0, + ulong offset0, + global half * dst, + ulong offsetd +) { + src0 = (global half*)((global char*)src0 + offset0); + dst = (global half*)((global char*)dst + offsetd); + + const float x = convert_float(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); +} + +kernel void kernel_softplus_f16_4( + global const half4 * src0, + ulong offset0, + global half4 * dst, + ulong offsetd +) { + src0 = (global half4*)((global char*)src0 + offset0); + dst = (global half4*)((global char*)dst + offsetd); + + const float4 x = convert_float4(src0[get_global_id(0)]); + dst[get_global_id(0)] = convert_half4_rte((x > 20.0f) ? x : log(1.0f + exp(x))); +} + +kernel void kernel_softplus_f32_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const float *src_val_ptr = (global const float *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global float *dst_val_ptr = (global float *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const float * x = (global const float *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global float * y = (global float *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = softplus_f32(*src_val_ptr); - } + *y = (*x > 20.0f) ? *x : log(1.0f + exp(*x)); } } -kernel void kernel_softplus_f16_nd( - global void * p_src0_base, - ulong off_src0_abs, - global void * p_dst_base, - ulong off_dst_abs, - int ne00, - int ne01, - int ne02, - int ne03, +kernel void kernel_softplus_f16_nc( + global const char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, ulong nb00, ulong nb01, ulong nb02, ulong nb03, - int ne10, - int ne11, - int ne12, - int ne13, - ulong nb10, - ulong nb11, - ulong nb12, - ulong nb13 + ulong nb0, + ulong nb1, + ulong nb2, + ulong nb3 ) { - int i0 = get_global_id(0); - int i1 = get_global_id(1); - int i2 = get_global_id(2); + src0 = src0 + offset0; + dst = dst + offsetd; - if (i0 < ne10 && i1 < ne11 && i2 < ne12) { - for (int i3 = 0; i3 < ne13; ++i3) { - ulong src_offset_in_tensor = (ulong)i0*nb00 + (ulong)i1*nb01 + (ulong)i2*nb02 + (ulong)i3*nb03; - global const half *src_val_ptr = (global const half *)((global char *)p_src0_base + off_src0_abs + src_offset_in_tensor); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); - ulong dst_offset_in_tensor = (ulong)i0*nb10 + (ulong)i1*nb11 + (ulong)i2*nb12 + (ulong)i3*nb13; - global half *dst_val_ptr = (global half *)((global char *)p_dst_base + off_dst_abs + dst_offset_in_tensor); + for (int i0 = get_local_id(0); i0 < ne00; i0 += get_local_size(0)) { + global const half * hx = (global const half *)(src0 + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); + global half * hy = (global half *)(dst + i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0); - *dst_val_ptr = (half)(softplus_f32((float)(*src_val_ptr))); - } + const float x = convert_float(*hx); + *hy = convert_half_rte((x > 20.0f) ? x : log(1.0f + exp(x))); } } diff --git a/ggml/src/ggml-opencl/kernels/sum_rows.cl b/ggml/src/ggml-opencl/kernels/sum_rows.cl index c5f7c570f9..84630aa8a3 100644 --- a/ggml/src/ggml-opencl/kernels/sum_rows.cl +++ b/ggml/src/ggml-opencl/kernels/sum_rows.cl @@ -1,8 +1,13 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +// Most devices have max workgroup size of 1024, so this is enough for subgroup +// sizes of 16, 32, 64 and 128. Increase this value for smaller subgroups sizes +#define MAX_SUBGROUPS 64 kernel void kernel_sum_rows_f32( - global float * src0, + global char * src0, ulong offset0, - global float * dst, + global char * dst, ulong offsetd, int ne00, int ne01, @@ -15,25 +20,121 @@ kernel void kernel_sum_rows_f32( ulong nb2, ulong nb3 ) { - src0 = (global float *)((global char *)src0 + offset0); - dst = (global float *)((global char *)dst + offsetd); + src0 = src0 + offset0; + dst = dst + offsetd; - int i3 = get_global_id(2); - int i2 = get_global_id(1); - int i1 = get_global_id(0); + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { return; } - global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03); - global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3); - - float row_sum = 0; - - for (int i0 = 0; i0 < ne00; i0++) { - row_sum += src_row[i0]; + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; } - dst_row[0] = row_sum; + global float * src_row = (global float *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float sumf = 0.0f; + + for (int i0 = lid; i0 < ne00; i0 += lsize) { + sumf += src_row[i0]; + } + + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } +} + +kernel void kernel_sum_rows_f32_4( + global char * src0, + ulong offset0, + global char * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne03, + ulong nb01, + ulong nb02, + ulong nb03, + ulong nb1, + ulong nb2, + ulong nb3 +) { + src0 = src0 + offset0; + dst = dst + offsetd; + + const int i3 = get_group_id(2); + const int i2 = get_group_id(1); + const int i1 = get_group_id(0); + + const int lid = get_local_id(0); + const int lsize = get_local_size(0); + + const uint sg_size = get_sub_group_size(); + const uint sg_id = get_sub_group_id(); + const uint sg_lid = get_sub_group_local_id(); + + __local float lmem[MAX_SUBGROUPS]; + + if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) { + return; + } + + if(sg_id == 0){ + lmem[sg_lid] = 0.0f; + } + + global float4 * src_row = (global float4 *) (src0 + i1*nb01 + i2*nb02 + i3*nb03); + global float * dst_row = (global float *) (dst + i1*nb1 + i2*nb2 + i3*nb3); + + float4 sum_vec = (float4)0.0f; + + for (int i0 = lid; i0 < ne00 / 4; i0 += lsize) { + sum_vec += src_row[i0]; + } + + float sumf = dot(sum_vec, (float4)(1.0f)); + sumf = sub_group_reduce_add(sumf); + + barrier(CLK_LOCAL_MEM_FENCE); + + if(sg_lid == 0){ + lmem[sg_id] = sumf; + } + + barrier(CLK_LOCAL_MEM_FENCE); + + sumf = lmem[sg_lid]; + sumf = sub_group_reduce_add(sumf); + + if (lid == 0) { + dst_row[0] = sumf; + } } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72097ffd0f..a8840a0773 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 @@ -687,6 +688,7 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_set_f32; // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] vk_pipeline pipeline_add[2][2][2]; @@ -942,6 +944,7 @@ struct vk_mat_mat_push_constants { uint32_t M; uint32_t N; uint32_t K; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t base_work_group_z; uint32_t num_batches; uint32_t k_split; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; uint32_t padded_N; @@ -961,6 +964,7 @@ struct vk_mat_vec_push_constants { uint32_t batch_stride_b; uint32_t batch_stride_d; uint32_t fusion_flags; + uint32_t base_work_group_y; uint32_t ne02; uint32_t ne12; uint32_t broadcast2; @@ -4080,7 +4084,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4181,7 +4185,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -5641,6 +5646,10 @@ static void ggml_vk_instance_init() { driver_priorities[vk::DriverId::eMesaNvk] = 2; #endif break; + case VK_VENDOR_ID_QUALCOMM: + driver_priorities[vk::DriverId::eQualcommProprietary] = 1; + driver_priorities[vk::DriverId::eMesaTurnip] = 2; + break; } driver_priorities[vk::DriverId::eMesaDozen] = 100; @@ -6766,8 +6775,16 @@ static void ggml_vk_matmul( uint32_t padded_n) { VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); if (split_k == 1) { - const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, groups_z }); + base_work_group_z += groups_z; + } return; } @@ -6781,9 +6798,17 @@ static void ggml_vk_matmul( uint32_t k_split = CEIL_DIV(k, split_k); k_split = ROUNDUP_POW2(k_split, 256); - const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; - // Make sure enough workgroups get assigned for split k to work - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, CEIL_DIV(batch, ctx->device->properties.limits.maxComputeWorkGroupCount[2])); + + uint32_t base_work_group_z = 0; + while (base_work_group_z < batch) { + uint32_t groups_z = std::min(batch - base_work_group_z, ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, base_work_group_z, batch, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, groups_z }); + base_work_group_z += groups_z; + } ggml_vk_sync_buffers(ctx, subctx); const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); @@ -7179,7 +7204,6 @@ static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& sub } // Request descriptor sets - ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); if (qx_needs_dequant) { ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); } @@ -7477,7 +7501,6 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& if (quantize_y) { ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); } - ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); } vk_subbuffer d_D = ggml_vk_tensor_subbuffer(ctx, cgraph->nodes[node_idx + ctx->num_additional_fused_ops]); @@ -7572,22 +7595,29 @@ static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& fusion_flags |= MAT_VEC_FUSION_FLAGS_BIAS1; } - // compute - const vk_mat_vec_push_constants pc = { - (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, - stride_batch_x, stride_batch_y, stride_batch_d, - fusion_flags, - (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, - }; - ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, - { - d_X, - d_Y, - d_D, - d_F0, - d_F1, - }, - pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + ggml_pipeline_request_descriptor_sets(ctx, dmmv, CEIL_DIV(ne12 * ne13, ctx->device->properties.limits.maxComputeWorkGroupCount[1])); + + uint32_t base_work_group_y = 0; + while (base_work_group_y < ne12 * ne13) { + + uint32_t groups_y = std::min((uint32_t)(ne12 * ne13) - base_work_group_y, ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + fusion_flags, base_work_group_y, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { + d_X, + d_Y, + d_D, + d_F0, + d_F1, + }, + pc, { groups_x, groups_y, groups_z }); + base_work_group_y += groups_y; + } if (x_non_contig) { ctx->prealloc_x_need_sync = true; @@ -7825,10 +7855,15 @@ static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, c src1->nb[2] <= src1->nb[1] && src1->nb[1] <= src1->nb[3] && src0->ne[3] == 1 && - src1->ne[3] == 1) { + src1->ne[3] == 1 && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, cgraph, node_idx); } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && - !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + !ggml_is_permuted(src0) && !ggml_is_permuted(src1) && + src0->ne[3] <= ctx->device->properties.limits.maxComputeWorkGroupCount[0] && + src0->ne[1] <= ctx->device->properties.limits.maxComputeWorkGroupCount[1] && + src1->ne[2] <= ctx->device->properties.limits.maxComputeWorkGroupCount[2]) { ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, cgraph, node_idx); // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) // when ne12 and ne13 are one. @@ -8422,6 +8457,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); + const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; @@ -8438,7 +8475,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t slope = Br * acctype; - const uint32_t total_size = Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); @@ -8815,6 +8852,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_acc_f32; } return nullptr; + case GGML_OP_SET: + if (src0->type == src1->type && src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) { + return ctx->device->pipeline_set_f32; + } + return nullptr; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -9801,16 +9844,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32 + int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32 + int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32 + int offset = dst->op_params[3] / src0_type_size; // offset in bytes - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, 0, 0.0f, 0.0f, offset, }); @@ -10624,8 +10667,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + const float * op_params = (const float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p)); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -11543,7 +11588,6 @@ static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t } } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -12052,7 +12096,6 @@ static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, // y[i] = i % k; } - ggml_pipeline_request_descriptor_sets(ctx, p, num_it); if (split_k > 1) { ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); @@ -12500,6 +12543,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ACC: + case GGML_OP_SET: ggml_vk_acc(ctx, compute_ctx, src0, src1, node); break; @@ -14896,8 +14940,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]) && + op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -14960,7 +15006,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: - return op->src[0]->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_SET: + return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32); case GGML_OP_CONCAT: return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: @@ -15611,6 +15660,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_SET) { + tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index 5084a70ed4..6ba3d1d89e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -3,6 +3,9 @@ #include "types.glsl" #include "generic_binary_head.glsl" +// false for SET, true for ACC +layout(constant_id = 1) const bool ACC = true; + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { @@ -13,17 +16,22 @@ void main() { const uint offset = p.param3; const uint src1_i = idx - offset; - const uint oz = src1_i / p.nb02; - const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; - const uint ox = src1_i % p.nb01; + const uint i3 = src1_i / p.nb03; + const uint rem2 = src1_i - i3 * p.nb03; + const uint i2 = rem2 / p.nb02; + const uint rem1 = rem2 - i2 * p.nb02; + const uint i1 = rem1 / p.nb01; + const uint i0 = rem1 % p.nb01; uint i00, i01, i02, i03; - get_indices(idx, i00, i01, i02, i03); - if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + if (ACC) { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } else { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } } else { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx])); } } - diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 914f131c96..0735f67854 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -130,6 +130,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -137,12 +138,25 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; + max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } float Sf[Br][cols_per_thread]; @@ -260,6 +274,9 @@ void main() { barrier(); } + // prevent race on tmpsh + barrier(); + // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index b317773823..19630972da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +shared float tmpsh[row_split]; + const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; @@ -213,6 +215,19 @@ void main() { } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 39f0c4d23b..853f17fa16 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -176,7 +176,14 @@ void main() { tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } else { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); // Don't clamp against nem1 when GQA is enabled @@ -184,7 +191,14 @@ void main() { tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index 83ef2f8795..7d0a1de0df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.glsl" +#include "generic_unary_head.glsl" #include "types.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,22 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint i3 = row / (p.ne11 * p.ne12); + const uint i3_offset = i3 * p.ne12 * p.ne11; + const uint i2 = (row - i3_offset) / p.ne11; + const uint i2_offset = i2 * p.ne11; + const uint i1 = row - i3_offset - i2_offset; + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); sum[tid] += xi * xi; } @@ -35,7 +38,7 @@ void main() { const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl index 4f2c700306..4aeda68c7f 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -32,6 +32,7 @@ layout (push_constant) uniform parameter uint expert_i1; uint nbi1; #else + uint base_work_group_y; uint ne02; uint ne12; uint broadcast2; @@ -45,9 +46,9 @@ uint expert_id; void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { #ifdef MUL_MAT_ID - const uint expert_i0 = gl_GlobalInvocationID.y; + const uint expert_i0 = gl_WorkGroupID.y; #else - const uint batch_idx = gl_GlobalInvocationID.y; + const uint batch_idx = gl_WorkGroupID.y + p.base_work_group_y; #endif #ifndef MUL_MAT_ID diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp index 775e9a70f6..79344d3300 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -90,6 +90,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -139,7 +141,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -149,7 +151,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -366,7 +368,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif #ifdef COOPMAT diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index b6614d2fc5..717d124e01 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -53,6 +53,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -197,7 +199,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -215,7 +217,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -255,7 +257,7 @@ void main() { #else uint pos_a = batch_idx_a * (p.batch_stride_a / QUANT_K); uint pos_b = batch_idx * p.batch_stride_b; - uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif uint stride_a = p.stride_a / QUANT_K; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp index 335d7f6a68..aae1c2e8ae 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -57,6 +57,8 @@ layout (push_constant) uniform parameter uint nbi1; uint ne11; #else + uint base_work_group_z; + uint num_batches; uint k_split; uint ne02; uint ne12; @@ -108,7 +110,7 @@ void main() { const uint ic = gl_WorkGroupID.y; #ifdef MUL_MAT_ID - const uint expert_idx = gl_GlobalInvocationID.z; + const uint expert_idx = gl_WorkGroupID.z; if (ic * BN >= data_expert_count[expert_idx]) { return; } @@ -118,7 +120,7 @@ void main() { #endif #ifndef MUL_MAT_ID - const uint batch_idx = gl_GlobalInvocationID.z; + const uint batch_idx = gl_WorkGroupID.z + p.base_work_group_z; const uint i13 = batch_idx / p.ne12; const uint i12 = batch_idx % p.ne12; @@ -276,7 +278,7 @@ void main() { const uint dc = ic * BN + warp_c * WN; #ifndef MUL_MAT_ID - const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * p.num_batches; #endif [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { diff --git a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp index 6997f6bdd3..0d5a818dac 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu-shader-lib.hpp @@ -1,10 +1,16 @@ #ifndef GGML_WEBGPU_SHADER_LIB_HPP #define GGML_WEBGPU_SHADER_LIB_HPP +#include "ggml-wgsl-shaders.hpp" #include "ggml.h" #include "pre_wgsl.hpp" +#include + +#include +#include #include +#include #include #define GGML_WEBGPU_F16_SIZE_BYTES 2 @@ -17,17 +23,203 @@ #define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u -struct ggml_webgpu_processed_shader { - std::string wgsl; - std::string variant; - void * decisions; -}; +// Matrix multiplication parameters + +// Register tiling parameters +#define WEBGPU_MUL_MAT_TILE_M 8 +#define WEBGPU_MUL_MAT_TILE_N 8 +#define WEBGPU_MUL_MAT_WG_SIZE_M 8 +#define WEBGPU_MUL_MAT_WG_SIZE_N 8 +#define WEBGPU_MUL_MAT_TILE_K 32 + +// Subgroup matrix parameters +// The number of subgroups in the M dimension +#define WEBGPU_MUL_MAT_SUBGROUP_M 2 +// The number of subgroups in the N dimension +#define WEBGPU_MUL_MAT_SUBGROUP_N 2 +// The number of subgroup matrices each subgroup accumulates over +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 +#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 + +// Matrix-vector multiplication parameters +#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 +// Must be multiple of 4 to work with vectorized paths, and must divide +// mul_mat_vec wg size +#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 +#define WEBGPU_MUL_MAT_VEC_TILE_K 256 + +// default size for legacy matrix multiplication +#define WEBGPU_MUL_MAT_WG_SIZE 256 // Same hash combine function as in boost template inline void ggml_webgpu_hash_combine(size_t & seed, const T & value) { seed ^= std::hash{}(value) + 0x9e3779b9 + (seed << 6) + (seed >> 2); } +struct ggml_webgpu_shader_lib_context { + ggml_tensor * src0; + ggml_tensor * src1; + ggml_tensor * src2; + ggml_tensor * src3; + ggml_tensor * src4; + ggml_tensor * dst; + + uint32_t max_wg_size; + size_t wg_mem_limit_bytes = 0; + bool inplace = false; + bool overlap = false; + bool supports_subgroup_matrix = false; + uint32_t sg_mat_m = 0; + uint32_t sg_mat_n = 0; + uint32_t sg_mat_k = 0; + uint32_t max_subgroup_size = 0; +}; + +struct webgpu_pipeline { + wgpu::ComputePipeline pipeline; + std::string name; + std::shared_ptr context = nullptr; +}; + +struct ggml_webgpu_generic_shader_decisions { + uint32_t wg_size = 0; +}; + +/** Argsort **/ + +struct ggml_webgpu_argsort_shader_lib_context { + uint32_t max_wg_size; + size_t wg_mem_limit_bytes; + int32_t order; +}; + +/** Set Rows **/ + +struct ggml_webgpu_set_rows_pipeline_key { + int dst_type; + int vec4; + int i64_idx; + + bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { + return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + } +}; + +struct ggml_webgpu_set_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.dst_type); + ggml_webgpu_hash_combine(seed, key.vec4); + ggml_webgpu_hash_combine(seed, key.i64_idx); + return seed; + } +}; + +struct ggml_webgpu_set_rows_shader_decisions { + bool vec4; + bool i64_idx; + uint32_t wg_size; +}; + +/** Get Rows **/ + +struct ggml_webgpu_get_rows_pipeline_key { + ggml_type src_type; + int vectorized; + + bool operator==(const ggml_webgpu_get_rows_pipeline_key & other) const { + return src_type == other.src_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_get_rows_pipeline_key_hash { + size_t operator()(const ggml_webgpu_get_rows_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +/** Pad **/ +struct ggml_webgpu_pad_pipeline_key { + bool circular; + + bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } +}; + +struct ggml_webgpu_pad_pipeline_key_hash { + size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.circular); + return seed; + } +}; + +/** Scale **/ + +struct ggml_webgpu_scale_pipeline_key { + int inplace; + + bool operator==(const ggml_webgpu_scale_pipeline_key & other) const { return inplace == other.inplace; } +}; + +struct ggml_webgpu_scale_pipeline_key_hash { + size_t operator()(const ggml_webgpu_scale_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + +/** Binary **/ + +struct ggml_webgpu_binary_pipeline_key { + int type; + int op; + bool inplace; + bool overlap; + + bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { + return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; + } +}; + +struct ggml_webgpu_binary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.inplace); + ggml_webgpu_hash_combine(seed, key.overlap); + return seed; + } +}; + +/** Unary **/ + +struct ggml_webgpu_unary_pipeline_key { + int type; + int op; + bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella + bool inplace; + + bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { + return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; + } +}; + +struct ggml_webgpu_unary_pipeline_key_hash { + size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.type); + ggml_webgpu_hash_combine(seed, key.op); + ggml_webgpu_hash_combine(seed, key.is_unary); + ggml_webgpu_hash_combine(seed, key.inplace); + return seed; + } +}; + /** FlashAttention */ struct ggml_webgpu_flash_attn_pipeline_key { @@ -99,439 +291,941 @@ inline size_t ggml_webgpu_flash_attn_wg_mem_bytes(uint32_t q_tile, return f16_elems * GGML_WEBGPU_F16_SIZE_BYTES + f32_elems * GGML_WEBGPU_F32_SIZE_BYTES; } -static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { - const size_t limit_bytes = context.wg_mem_limit_bytes; - const size_t q_tile = context.sg_mat_m; - const size_t base_q_bytes = - (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + - 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; - size_t bytes_per_kv = 0; - if (!context.key.kv_direct) { - bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); +/** Matrix Multiplication **/ + +struct ggml_webgpu_legacy_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + + bool operator==(const ggml_webgpu_legacy_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type; } - if (context.key.has_mask) { - bytes_per_kv += q_tile; - } - bytes_per_kv += q_tile; - bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; - const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; - return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; -} - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_flash_attn_shader_lib_context & context) { - std::vector defines; - std::string variant = "flash_attn"; - - switch (context.key.kv_type) { - case GGML_TYPE_F32: - defines.push_back("KV_F32"); - break; - case GGML_TYPE_F16: - defines.push_back("KV_F16"); - break; - case GGML_TYPE_Q4_0: - defines.push_back("KV_Q4_0"); - break; - case GGML_TYPE_Q8_0: - defines.push_back("KV_Q8_0"); - break; - default: - GGML_ABORT("Unsupported KV type for flash attention shader"); - } - variant += std::string("_") + ggml_type_name(context.key.kv_type); - - if (context.key.has_mask) { - defines.push_back("MASK"); - variant += "_mask"; - } - if (context.key.has_sinks) { - defines.push_back("SINKS"); - variant += "_sinks"; - } - if (context.key.uses_logit_softcap) { - defines.push_back("LOGIT_SOFTCAP"); - variant += "_lgsc"; - } - - if (context.key.kv_direct) { - defines.push_back("KV_DIRECT"); - variant += "_kvdirect"; - } - - defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(context.key.head_dim_qk)); - variant += std::string("_hsqk") + std::to_string(context.key.head_dim_qk); - - defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(context.key.head_dim_v)); - variant += std::string("_hsv") + std::to_string(context.key.head_dim_v); - // For now these are not part of the variant name - defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); - defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); - defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); - - // Add chosen Q/KV tile sizes - uint32_t q_tile = context.sg_mat_m; - uint32_t kv_tile = std::min(ggml_webgpu_flash_attn_max_kv_tile(context), - context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); - if (context.key.kv_direct) { - GGML_ASSERT(kv_tile <= GGML_WEBGPU_KV_SEQ_PAD); - // Avoids having to use bounds-checks and decreasing performance for direct KV loads - while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { - kv_tile -= context.sg_mat_n; - } - } - - defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); - defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); - - // workgroup size - uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); - - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions(); - decisions->q_tile = q_tile; - decisions->kv_tile = kv_tile; - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} - -/** Generic **/ - -struct ggml_webgpu_generic_shader_lib_context { - int vec4; - uint32_t max_wg_size; }; -struct ggml_webgpu_generic_shader_decisions { +struct ggml_webgpu_legacy_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_legacy_mul_mat_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_vec_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + + bool operator==(const ggml_webgpu_mul_mat_vec_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized; + } +}; + +struct ggml_webgpu_mul_mat_vec_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_vec_pipeline_key & key) const { + size_t seed = 0; + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + return seed; + } +}; + +struct ggml_webgpu_mul_mat_vec_shader_decisions { uint32_t wg_size; + uint32_t tile_k; + uint32_t outputs_per_wg; + uint32_t vec_size; }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_generic_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_generic_shader_lib_context & context, - const std::string & base_variant) { - std::vector defines; - std::string variant = base_variant; +struct ggml_webgpu_mul_mat_pipeline_key { + ggml_type src0_type; + ggml_type src1_type; + int vectorized; + int use_subgroup_matrix; - if (context.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; + bool operator==(const ggml_webgpu_mul_mat_pipeline_key & other) const { + return src0_type == other.src0_type && src1_type == other.src1_type && vectorized == other.vectorized && + use_subgroup_matrix == other.use_subgroup_matrix; } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - return result; -} - -/** Pad **/ - -struct ggml_webgpu_pad_pipeline_key { - bool circular; - - bool operator==(const ggml_webgpu_pad_pipeline_key & other) const { return circular == other.circular; } }; -struct ggml_webgpu_pad_pipeline_key_hash { - size_t operator()(const ggml_webgpu_pad_pipeline_key & key) const { +struct ggml_webgpu_mul_mat_pipeline_key_hash { + size_t operator()(const ggml_webgpu_mul_mat_pipeline_key & key) const { size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.circular); + ggml_webgpu_hash_combine(seed, key.src0_type); + ggml_webgpu_hash_combine(seed, key.src1_type); + ggml_webgpu_hash_combine(seed, key.vectorized); + ggml_webgpu_hash_combine(seed, key.use_subgroup_matrix); return seed; } }; -struct ggml_webgpu_pad_shader_lib_context { - ggml_webgpu_pad_pipeline_key key; - uint32_t max_wg_size; +struct ggml_webgpu_mul_mat_shader_decisions { + uint32_t tile_k; + uint32_t wg_size_m; + uint32_t wg_size_n; + uint32_t wg_size; + uint32_t outputs_per_wg; + int use_subgroup_matrix; + + uint32_t tile_m; + uint32_t tile_n; + + // Subgroup matrix parameters + uint32_t subgroup_m; + uint32_t subgroup_n; + uint32_t subgroup_matrix_m; + uint32_t subgroup_matrix_n; + + uint32_t mul_mat_wg_size; }; -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_pad_shader_lib_context & context) { - std::vector defines; - std::string variant = "pad"; +class ggml_webgpu_shader_lib { + wgpu::Device device; + pre_wgsl::Preprocessor preprocessor; - if (context.key.circular) { - defines.push_back("CIRCULAR"); - variant += "_circular"; + std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet + std::unordered_map argmax_pipelines; // key is vec4 + std::unordered_map argsort_pipelines; // key is order + std::unordered_map argsort_merge_pipelines; // key is order + std::unordered_map cumsum_pipelines; // key is fixed, no variants yet + std::unordered_map + get_rows_pipelines; // src_type, vectorized + std::unordered_map + unary_pipelines; // type/op/inplace + std::unordered_map + scale_pipelines; // inplace + std::unordered_map + pad_pipelines; // circular/non-circular + std::unordered_map + binary_pipelines; // type/op/inplace/overlap + std::unordered_map + flash_attn_pipelines; + std::unordered_map + mul_mat_legacy_pipelines; // legacy mul_mat (non-subgroup/non-regtile/non-vec) + std::unordered_map + mul_mat_vec_pipelines; // fast mat-vec (n==1) + std::unordered_map + mul_mat_fast_pipelines; // fast mat-mat (reg-tile or subgroup) + + std::unordered_map + set_rows_pipelines; + + public: + ggml_webgpu_shader_lib(wgpu::Device device) { this->device = device; } + + webgpu_pipeline get_sum_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = sum_rows_pipelines.find(1); + if (it != sum_rows_pipelines.end()) { + return it->second; + } + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_sum_rows, defines); + sum_rows_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "sum_rows"); + return sum_rows_pipelines[1]; } - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + webgpu_pipeline get_argmax_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool vec4 = context.src0->ne[0] % 4 == 0; - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} + auto it = argmax_pipelines.find(vec4); + if (it != argmax_pipelines.end()) { + return it->second; + } + std::string variant = "argmax"; + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + if (vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } -/** Argsort **/ - -struct ggml_webgpu_argsort_shader_lib_context { - uint32_t max_wg_size; - size_t wg_mem_limit_bytes; - int32_t order; -}; - -struct ggml_webgpu_argsort_shader_decisions { - uint32_t wg_size = 0; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = 1; - while (wg_size * 2 <= context.max_wg_size && - wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { - wg_size *= 2; + auto processed = preprocessor.preprocess(wgsl_argmax, defines); + argmax_pipelines[vec4] = ggml_webgpu_create_pipeline(device, processed, variant); + return argmax_pipelines.at(vec4); } - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_argsort_shader_lib_context & context) { - std::vector defines; - std::string variant = "argsort_merge"; - defines.push_back(std::string("ORDER=") + std::to_string(context.order)); - variant += std::string("_order") + std::to_string(context.order); - uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); - defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions(); - decisions->wg_size = wg_size; - result.decisions = decisions; - return result; -} + webgpu_pipeline get_set_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_set_rows_pipeline_key key = { .dst_type = context.dst->type, + .vec4 = context.src0->ne[0] % 4 == 0, + .i64_idx = context.src1->type == GGML_TYPE_I64 }; -/** Set Rows **/ + auto it = set_rows_pipelines.find(key); + if (it != set_rows_pipelines.end()) { + return it->second; + } -struct ggml_webgpu_set_rows_pipeline_key { - int dst_type; - int vec4; - int i64_idx; + std::vector defines; + std::string variant = "set_rows"; - bool operator==(const ggml_webgpu_set_rows_pipeline_key & other) const { - return dst_type == other.dst_type && vec4 == other.vec4 && i64_idx == other.i64_idx; + switch (context.dst->type) { + case GGML_TYPE_F32: + defines.push_back("DST_F32"); + variant += "_dstf32"; + break; + case GGML_TYPE_F16: + defines.push_back("DST_F16"); + variant += "_dstf16"; + break; + default: + GGML_ABORT("Unsupported dst type for set_rows shader"); + } + + if (key.vec4) { + defines.push_back("VEC4"); + variant += "_vec4"; + } + if (key.i64_idx) { + defines.push_back("I64_IDX"); + variant += "_i64idx"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_set_rows, defines); + auto decisions = std::make_shared(); + decisions->vec4 = key.vec4; + decisions->i64_idx = key.i64_idx; + decisions->wg_size = context.max_wg_size; + set_rows_pipelines[key] = ggml_webgpu_create_pipeline(device, processed, variant); + set_rows_pipelines[key].context = decisions; + return set_rows_pipelines[key]; + } + + webgpu_pipeline get_cumsum_pipeline(const ggml_webgpu_shader_lib_context & context) { + auto it = cumsum_pipelines.find(1); + if (it != cumsum_pipelines.end()) { + return it->second; + } + + std::vector defines; + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_cumsum, defines); + cumsum_pipelines[1] = ggml_webgpu_create_pipeline(device, processed, "cumsum"); + return cumsum_pipelines[1]; + } + + webgpu_pipeline get_argsort_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_pipelines.find(order); + if (it != argsort_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "argsort"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = 1; + while (wg_size * 2 <= context.max_wg_size && + wg_size * GGML_WEBGPU_I32_SIZE_BYTES <= context.wg_mem_limit_bytes / 2) { + wg_size *= 2; + } + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + auto processed = preprocessor.preprocess(wgsl_argsort, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + argsort_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + argsort_pipelines[order].context = decisions; + return argsort_pipelines[order]; + } + + webgpu_pipeline get_argsort_merge_pipeline(const ggml_webgpu_shader_lib_context & context) { + bool is_top_k = context.dst->op == GGML_OP_TOP_K; + // ascending order is 0, descending order is 1 + const int32_t order = + is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(context.dst, 0); + + auto it = argsort_merge_pipelines.find(order); + if (it != argsort_merge_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "argsort_merge"; + defines.push_back(std::string("ORDER=") + std::to_string(order)); + variant += std::string("_order") + std::to_string(order); + uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_argsort_merge, defines); + argsort_merge_pipelines[order] = ggml_webgpu_create_pipeline(device, processed, variant); + return argsort_merge_pipelines[order]; + } + + webgpu_pipeline get_get_rows_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool vectorized = context.src0->type == GGML_TYPE_F32 && context.dst->ne[0] % 4 == 0; + ggml_webgpu_get_rows_pipeline_key key = { + .src_type = context.src0->type, + .vectorized = (int) vectorized, + }; + + auto it = get_rows_pipelines.find(key); + if (it != get_rows_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "get_rows"; + + const struct ggml_type_traits * type_traits = ggml_get_type_traits(key.src_type); + const char * type_str = type_traits->type_name; + + switch (key.src_type) { + case GGML_TYPE_F32: + if (key.vectorized) { + defines.push_back("F32_VEC"); + defines.push_back("SRC_TYPE=vec4"); + defines.push_back("DST_TYPE=vec4"); + defines.push_back("BLOCK_SIZE=4u"); + } else { + defines.push_back("F32"); + defines.push_back("SRC_TYPE=f32"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + } + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("F16"); + defines.push_back("SRC_TYPE=f16"); + defines.push_back("DST_TYPE=f32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_f16"; + break; + case GGML_TYPE_I32: + defines.push_back("I32"); + defines.push_back("SRC_TYPE=i32"); + defines.push_back("DST_TYPE=i32"); + defines.push_back("BLOCK_SIZE=1u"); + variant += "_i32"; + break; + default: + { + std::string type_upper = type_str; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += "_"; + variant += type_str; + + defines.push_back(std::string("SRC_TYPE=") + type_str); + defines.push_back("DST_TYPE=f32"); + + if ((key.src_type >= GGML_TYPE_Q4_0 && key.src_type <= GGML_TYPE_Q8_1) || + key.src_type == GGML_TYPE_IQ4_NL) { + defines.push_back("BLOCK_SIZE=32u"); + } else if (key.src_type >= GGML_TYPE_Q2_K) { + defines.push_back("BLOCK_SIZE=256u"); + } else { + defines.push_back("BLOCK_SIZE=1u"); + } + break; + } + } + + if (key.vectorized) { + variant += "_vec"; + } + + defines.push_back("WG_SIZE=" + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_get_rows, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + get_rows_pipelines[key] = pipeline; + return get_rows_pipelines[key]; + } + + webgpu_pipeline get_scale_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_scale_pipeline_key key = { .inplace = context.inplace }; + + auto it = scale_pipelines.find(key); + if (it != scale_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "scale"; + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_scale, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + scale_pipelines[key] = pipeline; + return scale_pipelines[key]; + } + + webgpu_pipeline get_pad_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_pad_pipeline_key key = { .circular = ggml_get_op_params_i32(context.dst, 8) != 0 }; + + auto it = pad_pipelines.find(key); + if (it != pad_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "pad"; + + if (key.circular) { + defines.push_back("CIRCULAR"); + variant += "_circular"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_pad, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + pad_pipelines[key] = pipeline; + return pad_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_vec_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_vec_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + // Quantized mat-vec path currently runs scalar; only allow vectorization when both inputs are float + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + }; + + auto it = mul_mat_vec_pipelines.find(key); + if (it != mul_mat_vec_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat_vec"; + + // src1 type (vector) + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat_vec shader"); + } + + // src0 type (matrix row) + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("MUL_ACC_FLOAT"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("MUL_ACC_FLOAT"); + break; + default: + { + // Quantized types: use helpers but accumulate in f16 + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + std::string src0_name = src0_traits->type_name; + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + + // For fast path we always dequantize from f16 inside the shader + defines.push_back("SRC0_INNER_TYPE=f16"); + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + uint32_t wg_size = WEBGPU_MUL_MAT_VEC_WG_SIZE; + uint32_t tile_k = WEBGPU_MUL_MAT_VEC_TILE_K; + uint32_t outputs_per_wg = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + defines.push_back(std::string("TILE_K=") + std::to_string(tile_k)); + defines.push_back(std::string("OUTPUTS_PER_WG=") + std::to_string(outputs_per_wg)); + + auto processed = preprocessor.preprocess(wgsl_mul_mat_vec, defines); + auto decisions = std::make_shared(); + decisions->wg_size = wg_size; + decisions->tile_k = tile_k; + decisions->outputs_per_wg = outputs_per_wg; + decisions->vec_size = key.vectorized ? 4 : 1; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_vec_pipelines[key] = pipeline; + return mul_mat_vec_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_fast_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_mul_mat_pipeline_key key = { + .src0_type = context.src0->type, + .src1_type = context.src1->type, + .vectorized = (context.src0->ne[0] % 4 == 0 && context.dst->ne[0] % 4 == 0 && context.dst->ne[1] % 4 == 0 && + (context.src0->type == GGML_TYPE_F32 || context.src0->type == GGML_TYPE_F16)) ? + 1 : + 0, + .use_subgroup_matrix = context.supports_subgroup_matrix + }; + + auto it = mul_mat_fast_pipelines.find(key); + if (it != mul_mat_fast_pipelines.end()) { + return it->second; + } + + const char * shader_src = key.use_subgroup_matrix ? wgsl_mul_mat_subgroup_matrix : wgsl_mul_mat_reg_tile; + std::vector defines; + std::string variant = key.use_subgroup_matrix ? "mul_mat_subgroup_matrix" : "mul_mat_reg_tile"; + + // src1 type + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_INNER_TYPE=f32"); + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_INNER_TYPE=f16"); + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat fast shader"); + } + + // src0 type + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_INNER_TYPE=f32"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_INNER_TYPE=f16"); + defines.push_back("FLOAT"); + defines.push_back("MUL_ACC_FLOAT"); + defines.push_back("INIT_SRC0_SHMEM_FLOAT"); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + variant += "_f16"; + break; + default: + { + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back("BYTE_HELPERS"); + defines.push_back("MUL_ACC_" + type_upper); + defines.push_back("INIT_SRC0_SHMEM_" + type_upper); + defines.push_back("INIT_SRC1_SHMEM_FLOAT"); + + // Use f16 inside the shader for quantized types + defines.push_back("SRC0_INNER_TYPE=f16"); + + variant += std::string("_") + src0_name; + break; + } + } + + // VEC/SCALAR controls + defines.push_back(key.vectorized ? "VEC" : "SCALAR"); + + // Tiles + defines.push_back("TILE_M=" + std::to_string(WEBGPU_MUL_MAT_TILE_M) + "u"); + defines.push_back("TILE_N=" + std::to_string(WEBGPU_MUL_MAT_TILE_N) + "u"); + defines.push_back("TILE_K=" + std::to_string(WEBGPU_MUL_MAT_TILE_K) + "u"); + + // Subgroup matrix specifics + if (key.use_subgroup_matrix) { + defines.push_back("MAX_SUBGROUP_SIZE=" + std::to_string(context.max_subgroup_size) + "u"); + defines.push_back("SUBGROUP_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M) + "u"); + defines.push_back("SUBGROUP_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M) + "u"); + defines.push_back("SUBGROUP_MATRIX_N=" + std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N) + "u"); + defines.push_back("SUBGROUP_MATRIX_M_SIZE=" + std::to_string(context.sg_mat_m) + "u"); + defines.push_back("SUBGROUP_MATRIX_N_SIZE=" + std::to_string(context.sg_mat_n) + "u"); + defines.push_back("SUBGROUP_MATRIX_K_SIZE=" + std::to_string(context.sg_mat_k) + "u"); + } + + // variant suffix for src1 type + variant += std::string("_") + (context.src1->type == GGML_TYPE_F32 ? "f32" : "f16"); + if (key.vectorized) { + variant += "_vectorized"; + } + + if (!key.use_subgroup_matrix) { + defines.push_back("WORKGROUP_SIZE_M=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_M) + "u"); + defines.push_back("WORKGROUP_SIZE_N=" + std::to_string(WEBGPU_MUL_MAT_WG_SIZE_N) + "u"); + } + + auto processed = preprocessor.preprocess(shader_src, defines); + + auto decisions = std::make_shared(); + decisions->tile_k = WEBGPU_MUL_MAT_TILE_K; + decisions->tile_m = WEBGPU_MUL_MAT_TILE_M; + decisions->tile_n = WEBGPU_MUL_MAT_TILE_N; + decisions->use_subgroup_matrix = key.use_subgroup_matrix; + if (key.use_subgroup_matrix) { + decisions->subgroup_m = WEBGPU_MUL_MAT_SUBGROUP_M; + decisions->subgroup_n = WEBGPU_MUL_MAT_SUBGROUP_N; + decisions->subgroup_matrix_m = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M; + decisions->subgroup_matrix_n = WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N; + decisions->wg_size = context.max_subgroup_size; + } else { + decisions->wg_size_m = WEBGPU_MUL_MAT_WG_SIZE_M; + decisions->wg_size_n = WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE_M * WEBGPU_MUL_MAT_WG_SIZE_N; + decisions->mul_mat_wg_size = WEBGPU_MUL_MAT_WG_SIZE; + } + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_fast_pipelines[key] = pipeline; + return mul_mat_fast_pipelines[key]; + } + + webgpu_pipeline get_mul_mat_legacy_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_legacy_mul_mat_pipeline_key key = { .src0_type = context.src0->type, + .src1_type = context.src1->type }; + + auto it = mul_mat_legacy_pipelines.find(key); + if (it != mul_mat_legacy_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "mul_mat"; + + switch (context.src1->type) { + case GGML_TYPE_F32: + defines.push_back("SRC1_TYPE=f32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC1_TYPE=f16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported src1 type for mul_mat legacy shader"); + } + + const struct ggml_type_traits * src0_traits = ggml_get_type_traits(context.src0->type); + const char * src0_name = src0_traits->type_name; + + switch (context.src0->type) { + case GGML_TYPE_F32: + defines.push_back("SRC0_TYPE=f32"); + defines.push_back("FLOAT"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("SRC0_TYPE=f16"); + defines.push_back("FLOAT"); + variant += "_f16"; + break; + default: + { + // quantized types + std::string type_upper = src0_name; + std::transform(type_upper.begin(), type_upper.end(), type_upper.begin(), ::toupper); + + defines.push_back(std::string("SRC0_TYPE=") + src0_name); + defines.push_back("BYTE_HELPERS"); + defines.push_back(type_upper + "_T"); + defines.push_back(type_upper); + defines.push_back(type_upper + "_SCALE_MIN"); + defines.push_back(type_upper + "_TABLES"); + defines.push_back(type_upper + "_GRID"); + + variant += std::string("_") + src0_name; + break; + } + } + + auto processed = preprocessor.preprocess(wgsl_mul_mat, defines); + + auto decisions = std::make_shared(); + decisions->wg_size = WEBGPU_MUL_MAT_WG_SIZE; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + mul_mat_legacy_pipelines[key] = pipeline; + return mul_mat_legacy_pipelines[key]; + } + + webgpu_pipeline get_unary_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool is_unary = context.dst->op == GGML_OP_UNARY; + const int op = is_unary ? (int) ggml_get_unary_op(context.dst) : context.dst->op; + ggml_webgpu_unary_pipeline_key key = { + .type = context.dst->type, + .op = op, + .is_unary = is_unary, + .inplace = context.inplace, + }; + + auto it = unary_pipelines.find(key); + if (it != unary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = + key.is_unary ? ggml_unary_op_name((ggml_unary_op) key.op) : ggml_op_name((ggml_op) key.op); + defines.push_back(variant); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for unary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_unary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + unary_pipelines[key] = pipeline; + return unary_pipelines[key]; + } + + webgpu_pipeline get_binary_pipeline(const ggml_webgpu_shader_lib_context & context) { + ggml_webgpu_binary_pipeline_key key = { + .type = context.dst->type, + .op = context.dst->op, + .inplace = context.inplace, + .overlap = context.overlap, + }; + + auto it = binary_pipelines.find(key); + if (it != binary_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string op_name = ggml_op_name((ggml_op) key.op); + std::string variant = op_name; + + defines.push_back(std::string("OP_") + op_name); + + switch (key.type) { + case GGML_TYPE_F32: + defines.push_back("TYPE_F32"); + variant += "_f32"; + break; + case GGML_TYPE_F16: + defines.push_back("TYPE_F16"); + variant += "_f16"; + break; + default: + GGML_ABORT("Unsupported type for binary shader"); + } + + if (key.inplace) { + defines.push_back("INPLACE"); + variant += "_inplace"; + } else if (key.overlap) { + defines.push_back("OVERLAP"); + variant += "_overlap"; + } + + defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); + + auto processed = preprocessor.preprocess(wgsl_binary, defines); + auto decisions = std::make_shared(); + decisions->wg_size = context.max_wg_size; + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + binary_pipelines[key] = pipeline; + return binary_pipelines[key]; + } + + webgpu_pipeline get_flash_attn_pipeline(const ggml_webgpu_shader_lib_context & context) { + const bool has_mask = context.src3 != nullptr; + const bool has_sinks = context.src4 != nullptr; + + bool kv_direct = (context.src1->type == GGML_TYPE_F16) && (context.src0->ne[0] % context.sg_mat_k == 0) && + (context.src1->ne[1] % context.sg_mat_n == 0); + + ggml_webgpu_flash_attn_pipeline_key key = { + .kv_type = context.src1->type, + .head_dim_qk = (uint32_t) context.src0->ne[0], + .head_dim_v = (uint32_t) context.src2->ne[0], + .kv_direct = kv_direct, + .has_mask = has_mask, + .has_sinks = has_sinks, + .uses_logit_softcap = (*(float *) &context.dst->op_params[2]) != 0.0f, + }; + + auto it = flash_attn_pipelines.find(key); + if (it != flash_attn_pipelines.end()) { + return it->second; + } + + std::vector defines; + std::string variant = "flash_attn"; + + switch (key.kv_type) { + case GGML_TYPE_F32: + defines.push_back("KV_F32"); + break; + case GGML_TYPE_F16: + defines.push_back("KV_F16"); + break; + case GGML_TYPE_Q4_0: + defines.push_back("KV_Q4_0"); + break; + case GGML_TYPE_Q8_0: + defines.push_back("KV_Q8_0"); + break; + default: + GGML_ABORT("Unsupported KV type for flash attention shader"); + } + variant += std::string("_") + ggml_type_name(key.kv_type); + + if (key.has_mask) { + defines.push_back("MASK"); + variant += "_mask"; + } + if (key.has_sinks) { + defines.push_back("SINKS"); + variant += "_sinks"; + } + if (key.uses_logit_softcap) { + defines.push_back("LOGIT_SOFTCAP"); + variant += "_lgsc"; + } + if (key.kv_direct) { + defines.push_back("KV_DIRECT"); + variant += "_kvdirect"; + } + + defines.push_back(std::string("HEAD_DIM_QK=") + std::to_string(key.head_dim_qk)); + variant += std::string("_hsqk") + std::to_string(key.head_dim_qk); + + defines.push_back(std::string("HEAD_DIM_V=") + std::to_string(key.head_dim_v)); + variant += std::string("_hsv") + std::to_string(key.head_dim_v); + + defines.push_back(std::string("SG_MAT_M=") + std::to_string(context.sg_mat_m)); + defines.push_back(std::string("SG_MAT_N=") + std::to_string(context.sg_mat_n)); + defines.push_back(std::string("SG_MAT_K=") + std::to_string(context.sg_mat_k)); + + uint32_t q_tile = context.sg_mat_m; + uint32_t kv_tile = + std::min(ggml_webgpu_flash_attn_max_kv_tile({ key, context.sg_mat_m, context.sg_mat_n, context.sg_mat_k, + context.wg_mem_limit_bytes, context.max_subgroup_size }), + context.sg_mat_n * GGML_WEBGPU_FLASH_ATTN_PREFERRED_KV_SG_TILES); + if (key.kv_direct) { + while (GGML_WEBGPU_KV_SEQ_PAD % kv_tile != 0) { + kv_tile -= context.sg_mat_n; + } + } + + defines.push_back(std::string("Q_TILE=") + std::to_string(q_tile)); + defines.push_back(std::string("KV_TILE=") + std::to_string(kv_tile)); + + uint32_t wg_size = std::max(context.max_subgroup_size, GGML_WEBGPU_FLASH_ATTN_PREFERRED_WG_SIZE); + defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size)); + + auto processed = preprocessor.preprocess(wgsl_flash_attn, defines); + auto decisions = std::make_shared(); + decisions->q_tile = q_tile; + decisions->kv_tile = kv_tile; + decisions->wg_size = wg_size; + + webgpu_pipeline pipeline = ggml_webgpu_create_pipeline(device, processed, variant); + pipeline.context = decisions; + flash_attn_pipelines[key] = pipeline; + return flash_attn_pipelines[key]; + } + + private: + static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, + std::string shader_code, + std::string label) { + wgpu::ShaderSourceWGSL shader_source; + shader_source.code = shader_code.c_str(); + + wgpu::ShaderModuleDescriptor shader_desc; + shader_desc.nextInChain = &shader_source; + + wgpu::ShaderModule shader_module = device.CreateShaderModule(&shader_desc); + + wgpu::ComputePipelineDescriptor pipeline_desc; + pipeline_desc.label = label.c_str(); + pipeline_desc.compute.module = shader_module; + pipeline_desc.compute.entryPoint = "main"; // Entry point in the WGSL code + pipeline_desc.layout = nullptr; // nullptr means auto layout + return { device.CreateComputePipeline(&pipeline_desc), label }; + } + + static uint32_t ggml_webgpu_flash_attn_max_kv_tile(const ggml_webgpu_flash_attn_shader_lib_context & context) { + const size_t limit_bytes = context.wg_mem_limit_bytes; + const size_t q_tile = context.sg_mat_m; + const size_t base_q_bytes = + (context.key.head_dim_qk + context.key.head_dim_v) * q_tile * GGML_WEBGPU_F16_SIZE_BYTES + + 2 * q_tile * GGML_WEBGPU_F32_SIZE_BYTES; + size_t bytes_per_kv = 0; + if (!context.key.kv_direct) { + bytes_per_kv += std::max(context.key.head_dim_qk, context.key.head_dim_v); + } + if (context.key.has_mask) { + bytes_per_kv += q_tile; + } + bytes_per_kv += q_tile; + bytes_per_kv *= GGML_WEBGPU_F16_SIZE_BYTES; + const uint32_t max_kv_tile = (limit_bytes - base_q_bytes) / bytes_per_kv; + return (max_kv_tile / context.sg_mat_n) * context.sg_mat_n; } }; -struct ggml_webgpu_set_rows_pipeline_key_hash { - size_t operator()(const ggml_webgpu_set_rows_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.dst_type); - ggml_webgpu_hash_combine(seed, key.vec4); - ggml_webgpu_hash_combine(seed, key.i64_idx); - return seed; - } -}; - -struct ggml_webgpu_set_rows_shader_lib_context { - ggml_webgpu_set_rows_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_set_rows_shader_lib_context & context) { - std::vector defines; - std::string variant = "set_rows"; - - switch (context.key.dst_type) { - case GGML_TYPE_F32: - defines.push_back("DST_F32"); - variant += "_dstf32"; - break; - case GGML_TYPE_F16: - defines.push_back("DST_F16"); - variant += "_dstf16"; - break; - default: - GGML_ABORT("Unsupported dst type for set_rows shader"); - } - - if (context.key.vec4) { - defines.push_back("VEC4"); - variant += "_vec"; - } - if (context.key.i64_idx) { - defines.push_back("I64_IDX"); - variant += "_i64idx"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} - -struct ggml_webgpu_unary_pipeline_key { - int type; - int op; - bool is_unary; // many unary operators fall under the GGML_OP_UNARY umbrella - bool inplace; - - bool operator==(const ggml_webgpu_unary_pipeline_key & other) const { - return type == other.type && op == other.op && is_unary == other.is_unary && inplace == other.inplace; - } -}; - -struct ggml_webgpu_unary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_unary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.is_unary); - ggml_webgpu_hash_combine(seed, key.inplace); - return seed; - } -}; - -struct ggml_webgpu_unary_shader_lib_context { - ggml_webgpu_unary_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_unary_shader_lib_context & context) { - std::vector defines; - std::string variant = context.key.is_unary ? ggml_unary_op_name((ggml_unary_op) context.key.op) : - ggml_op_name((ggml_op) context.key.op); - // Operation-specific behavior - defines.push_back(variant); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for unary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} - -/** Binary **/ - -struct ggml_webgpu_binary_pipeline_key { - int type; - int op; - bool inplace; - bool overlap; - - bool operator==(const ggml_webgpu_binary_pipeline_key & other) const { - return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap; - } -}; - -struct ggml_webgpu_binary_pipeline_key_hash { - size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const { - size_t seed = 0; - ggml_webgpu_hash_combine(seed, key.type); - ggml_webgpu_hash_combine(seed, key.op); - ggml_webgpu_hash_combine(seed, key.inplace); - ggml_webgpu_hash_combine(seed, key.overlap); - return seed; - } -}; - -struct ggml_webgpu_binary_shader_lib_context { - ggml_webgpu_binary_pipeline_key key; - uint32_t max_wg_size; -}; - -inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader( - pre_wgsl::Preprocessor & preprocessor, - const char * shader_src, - const ggml_webgpu_binary_shader_lib_context & context) { - std::vector defines; - std::string op_name = ggml_op_name((ggml_op) context.key.op); - std::string variant = op_name; - - defines.push_back(std::string("OP_") + op_name); - - switch (context.key.type) { - case GGML_TYPE_F32: - defines.push_back("TYPE_F32"); - variant += "_f32"; - break; - case GGML_TYPE_F16: - defines.push_back("TYPE_F16"); - variant += "_f16"; - break; - default: - GGML_ABORT("Unsupported type for binary shader"); - } - - if (context.key.inplace) { - defines.push_back("INPLACE"); - variant += "_inplace"; - } else if (context.key.overlap) { - defines.push_back("OVERLAP"); - variant += "_overlap"; - } - - defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size)); - ggml_webgpu_processed_shader result; - result.wgsl = preprocessor.preprocess(shader_src, defines); - result.variant = variant; - ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions(); - decisions->wg_size = context.max_wg_size; - result.decisions = decisions; - return result; -} #endif // GGML_WEBGPU_SHADER_LIB_HPP diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index f7ceca1121..b5fee48056 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -8,7 +8,6 @@ #include "ggml-backend-impl.h" #include "ggml-impl.h" #include "ggml-webgpu-shader-lib.hpp" -#include "ggml-wgsl-shaders.hpp" #include "pre_wgsl.hpp" #ifdef __EMSCRIPTEN__ @@ -23,6 +22,7 @@ #include #include #include +#include #include #include #include @@ -69,50 +69,29 @@ /* Constants */ -// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to implementations so this can be removed. -#define WEBGPU_MAX_WG_SIZE 288 - -#define WEBGPU_MUL_MAT_WG_SIZE 256 #define WEBGPU_NUM_PARAM_BUFS 16u #define WEBGPU_COMMAND_SUBMIT_BATCH_SIZE 8u #define WEBGPU_WAIT_ANY_TIMEOUT_MS 0 -// Maximum number of in-flight submissions per-thread, to avoid exhausting the parameter buffer pool +// Maximum number of in-flight submissions per-thread, to avoid exhausting the +// parameter buffer pool #define WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD WEBGPU_NUM_PARAM_BUFS / WEBGPU_COMMAND_SUBMIT_BATCH_SIZE #define WEBGPU_PARAMS_BUF_SIZE_BYTES 128 // enough for 32 parameters #define WEBGPU_NUM_SET_ROWS_ERROR_BUFS 16 #define WEBGPU_SET_ROWS_ERROR_BUF_SIZE_BYTES 4 #define WEBGPU_STORAGE_BUF_BINDING_MULT 4 // a storage buffer binding size must be a multiple of 4 -// For operations which process a row in parallel, this seems like a reasonable default +// For operations which process a row in parallel, this seems like a reasonable +// default #define WEBGPU_ROW_SPLIT_WG_SIZE 64 -// Matrix multiplication parameters - -// Register tiling parameters -#define WEBGPU_MUL_MAT_TILE_M 8 -#define WEBGPU_MUL_MAT_TILE_N 8 -#define WEBGPU_MUL_MAT_WG_SIZE_M 8 -#define WEBGPU_MUL_MAT_WG_SIZE_N 8 -#define WEBGPU_MUL_MAT_TILE_K 32 - -// Subgroup matrix parameters -// The number of subgroups in the M dimension -#define WEBGPU_MUL_MAT_SUBGROUP_M 2 -// The number of subgroups in the N dimension -#define WEBGPU_MUL_MAT_SUBGROUP_N 2 -// The number of subgroup matrices each subgroup accumulates over -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M 4 -#define WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N 2 - -// Matrix-vector multiplication parameters -#define WEBGPU_MUL_MAT_VEC_WG_SIZE 256 -// Must be multiple of 4 to work with vectorized paths, and must divide mul_mat_vec wg size -#define WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG 64 -#define WEBGPU_MUL_MAT_VEC_TILE_K 256 +// Track https://github.com/gpuweb/gpuweb/issues/5315 for fixes to +// implementations so this can be removed, necessary only for get_rows right now +#define WEBGPU_MAX_WG_SIZE 288 /* End Constants */ -// This is a "fake" base pointer, since WebGPU buffers do not have pointers to their locations. +// This is a "fake" base pointer, since WebGPU buffers do not have pointers to +// their locations. static void * const webgpu_ptr_base = (void *) (uintptr_t) 0x1000; // NOLINT // Always returns the base offset of a tensor, regardless of views. @@ -186,11 +165,17 @@ struct webgpu_buf_pool { void cleanup() { std::lock_guard lock(mutex); for (auto & bufs : free) { - bufs.host_buf.Destroy(); - bufs.dev_buf.Destroy(); + if (bufs.host_buf) { + bufs.host_buf.Destroy(); + } + if (bufs.dev_buf) { + bufs.dev_buf.Destroy(); + } } free.clear(); } + + ~webgpu_buf_pool() { this->cleanup(); } }; #ifdef GGML_WEBGPU_GPU_PROFILE @@ -252,15 +237,11 @@ struct webgpu_gpu_profile_buf_pool { } free.clear(); } + + ~webgpu_gpu_profile_buf_pool() { this->cleanup(); } }; #endif -struct webgpu_pipeline { - wgpu::ComputePipeline pipeline; - std::string name; - void * context = nullptr; -}; - struct webgpu_command { wgpu::CommandBuffer commands; std::vector params_bufs; @@ -319,6 +300,23 @@ struct webgpu_global_context_struct { wgpu::Buffer debug_host_buf; wgpu::Buffer debug_dev_buf; #endif + + ~webgpu_global_context_struct() { + if (this->get_tensor_staging_buf) { + this->get_tensor_staging_buf.Destroy(); + this->get_tensor_staging_buf = nullptr; + } +#ifdef GGML_WEBGPU_DEBUG + if (this->debug_host_buf) { + this->debug_host_buf.Destroy(); + this->debug_host_buf = nullptr; + } + if (this->debug_dev_buf) { + this->debug_dev_buf.Destroy(); + this->debug_dev_buf = nullptr; + } +#endif + } }; typedef std::shared_ptr webgpu_global_context; @@ -328,41 +326,18 @@ struct webgpu_context_struct { // Points to global instances owned by ggml_backend_webgpu_reg_context webgpu_global_context global_ctx; - pre_wgsl::Preprocessor p; + std::unique_ptr shader_lib; webgpu_buf_pool param_buf_pool; webgpu_buf_pool set_rows_error_buf_pool; - std::map>> mul_mat_pipelines; // src0_type, src1_type, vectorized - std::map>> - mul_mat_vec_pipelines; // src0_type, src1_type, vectorized - - std::unordered_map - flash_attn_pipelines; - - std::unordered_map argmax_pipelines; // key is vec4 - std::unordered_map argsort_pipelines; // key is order (asc/desc) - std::unordered_map argsort_merge_pipelines; // key is order (asc/desc) - std::unordered_map cumsum_pipelines; // key is fixed, no variants yet - std::unordered_map sum_rows_pipelines; // key is fixed, no variants yet - - std::unordered_map - set_rows_pipelines; - std::map> get_rows_pipelines; // src_type, vectorized - - std::map> cpy_pipelines; // src_type, dst_type - - std::unordered_map - binary_pipelines; + std::map> cpy_pipelines; // src_type, dst_type std::map rms_norm_pipelines; // inplace std::map>> rope_pipelines; // type, ff, inplace std::map>> glu_pipelines; // glu_op, type, split - std::map scale_pipelines; // inplace + std::map>> soft_max_pipelines; // mask_type, has_sink, inplace - std::unordered_map - unary_pipelines; - std::unordered_map pad_pipelines; size_t memset_bytes_per_thread; }; @@ -404,25 +379,6 @@ struct ggml_backend_webgpu_buffer_context { /* WebGPU object initializations */ -// Process a WGSL shader string, replacing tokens of the form {{KEY}} with -// the corresponding values provided in `repls`. -static std::string ggml_webgpu_process_shader_repls(const char * src, - const std::map & repls) { - if (!src) { - return std::string(); - } - std::string s = src; - for (const auto & kv : repls) { - std::string token = "{{" + kv.first + "}}"; - size_t pos = 0; - while ((pos = s.find(token, pos)) != std::string::npos) { - s.replace(pos, token.length(), kv.second); - pos += kv.second.length(); - } - } - return s; -} - static webgpu_pipeline ggml_webgpu_create_pipeline(wgpu::Device & device, const char * shader_code, const char * label, @@ -470,8 +426,9 @@ static void ggml_webgpu_create_buffer(wgpu::Device & device, static void ggml_backend_webgpu_wait(webgpu_global_context & ctx, std::vector & futures, bool block = true) { - // If we have too many in-flight submissions, wait on the oldest one first. If there are many threads, - // inflight_max may be 0, meaning that we must wait on all futures. + // If we have too many in-flight submissions, wait on the oldest one first. If + // there are many threads, inflight_max may be 0, meaning that we must wait on + // all futures. uint64_t timeout_ms = block ? UINT64_MAX : 0; uint32_t inflight_threads = ctx->inflight_threads; uint32_t inflight_max = WEBGPU_MAX_INFLIGHT_SUBS_PER_THREAD / std::max(inflight_threads, 1u); @@ -656,7 +613,8 @@ static webgpu_command ggml_backend_webgpu_build_multi( encoder.CopyBufferToBuffer(params_bufs.host_buf, 0, params_bufs.dev_buf, 0, params_bufs.dev_buf.GetSize()); } - // If there are SET_ROWS operations in this submission, copy their error buffers to the host. + // If there are SET_ROWS operations in this submission, copy their error + // buffers to the host. if (set_rows_error_bufs) { encoder.CopyBufferToBuffer(set_rows_error_bufs->dev_buf, 0, set_rows_error_bufs->host_buf, 0, set_rows_error_bufs->host_buf.GetSize()); @@ -744,7 +702,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) { return ctx->name.c_str(); } -// TODO: implement proper cleanup static void ggml_backend_webgpu_free(ggml_backend_t backend) { ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context; WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")"); @@ -788,9 +745,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) { std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n"; #endif -#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE) - GGML_UNUSED(ctx); -#endif + delete ctx; + delete backend; } static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) { @@ -877,27 +833,13 @@ static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, g } static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - const bool circular = ggml_get_op_params_i32(dst, 8) != 0; - - ggml_webgpu_pad_pipeline_key pipeline_key = { .circular = circular }; - ggml_webgpu_pad_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->pad_pipelines.find(pipeline_key); - if (it != ctx->pad_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = ggml_webgpu_preprocess_pad_shader(ctx->p, wgsl_pad, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->pad_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_pad_pipeline(shader_lib_ctx); - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); const uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -941,7 +883,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -949,37 +891,25 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - // For set rows specifically, we need to check if src and idx are empty tensors. + // For set rows specifically, we need to check if src and idx are empty + // tensors. if (ggml_is_empty(src) || ggml_is_empty(idx)) { return std::nullopt; } - ggml_webgpu_set_rows_pipeline_key key = { .dst_type = dst->type, - .vec4 = src->ne[0] % 4 == 0, - .i64_idx = idx->type == GGML_TYPE_I64 }; - - ggml_webgpu_set_rows_shader_lib_context shader_lib_ctx = { - .key = key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = idx, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->set_rows_pipelines.find(key); - if (it != ctx->set_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_set_rows_shader(ctx->p, wgsl_set_rows, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->set_rows_pipelines.emplace(key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_set_rows_pipeline(shader_lib_ctx); - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); std::optional error_bufs = std::nullopt; - if (key.i64_idx) { + if (decisions->i64_idx) { error_bufs = ctx->set_rows_error_buf_pool.alloc_bufs(); if (error_bufs->host_buf.GetMapState() == wgpu::BufferMapState::Mapped) { error_bufs->host_buf.Unmap(); @@ -1017,42 +947,63 @@ static std::optional ggml_webgpu_set_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - if (key.i64_idx) { + if (decisions->i64_idx) { entries.push_back( { .binding = 3, .buffer = error_bufs->dev_buf, .offset = 0, .size = error_bufs->dev_buf.GetSize() }); } uint32_t threads; - if (key.vec4) { + if (decisions->vec4) { threads = (src->ne[1] * src->ne[2] * src->ne[3]) * (src->ne[0] / 4); } else { threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3]; } - uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1, error_bufs); } +// Workgroup size is a common constant +static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { + std::vector constants(1); + constants[0].key = "wg_size"; + constants[0].value = wg_size; + return constants; +} + static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - // Convert byte-strides to element-strides - (uint32_t) (src->nb[1] / ggml_type_size(src->type)), (uint32_t) (src->nb[2] / ggml_type_size(src->type)), - (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), - (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), - (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), - (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Shape of dst - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2], (uint32_t) dst->ne[3], - // Shape of idx - (uint32_t) (idx->ne[1]), (uint32_t) (idx->ne[2]) + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = WEBGPU_MAX_WG_SIZE, }; + webgpu_pipeline pipeline = ctx->shader_lib->get_get_rows_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, idx) / ggml_type_size(idx->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (idx->nb[0] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[1] / ggml_type_size(idx->type)), + (uint32_t) (idx->nb[2] / ggml_type_size(idx->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) dst->ne[3], + (uint32_t) (idx->ne[1]), + (uint32_t) (idx->ne[2]) }; + std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), @@ -1068,10 +1019,8 @@ static webgpu_command ggml_webgpu_get_rows(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MAX_WG_SIZE); + uint32_t wg_x = CEIL_DIV(dst->ne[1] * dst->ne[2] * dst->ne[3], decisions->wg_size); - uint32_t vectorized = src->type == GGML_TYPE_F32 && dst->ne[0] % 4 == 0; - webgpu_pipeline pipeline = ctx->get_rows_pipelines[src->type][vectorized]; return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1079,45 +1028,10 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { - std::vector params = { - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), - (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), - (uint32_t) dst->ne[0], // number of rows in result (M, transposed) - (uint32_t) dst->ne[1], // number of columns in result (N) - (uint32_t) src0->ne[0], // number of columns in src0/src1 (K) - (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 1 - (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 1 - (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 2 - (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 2 - (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), // stride (elements/blocks) of src0 in dimension 3 - (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), // stride (elements/blocks) of src1 in dimension 3 - (uint32_t) src0->ne[2], // batch size in dimension 2 - (uint32_t) src0->ne[3], // batch size in dimension 3 - (uint32_t) (src1->ne[2] / src0->ne[2]), // broadcast in dimension 2 - (uint32_t) (src1->ne[3] / src0->ne[3]) // broadcast in dimension 3 - }; - - std::vector entries = { - { .binding = 0, - .buffer = ggml_webgpu_tensor_buf(src0), - .offset = ggml_webgpu_tensor_align_offset(ctx, src0), - .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, - { .binding = 1, - .buffer = ggml_webgpu_tensor_buf(src1), - .offset = ggml_webgpu_tensor_align_offset(ctx, src1), - .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, - { .binding = 2, - .buffer = ggml_webgpu_tensor_buf(dst), - .offset = ggml_webgpu_tensor_align_offset(ctx, dst), - .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, - }; - - webgpu_pipeline pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][0]; - - uint32_t wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], WEBGPU_MUL_MAT_WG_SIZE); - uint32_t wg_y = 1; + // Determine if this is a mat-vec operation + bool is_vec = (dst->ne[1] == 1); + // Determine if we should use fast path bool use_fast = false; switch (src1->type) { case GGML_TYPE_F16: @@ -1138,43 +1052,104 @@ static webgpu_command ggml_webgpu_mul_mat(webgpu_context & ctx, break; } - if (use_fast) { - int vectorized = src0->ne[0] % 4 == 0 && dst->ne[0] % 4 == 0 && dst->ne[1] % 4 == 0; - if (dst->ne[1] == 1) { - // We don't support vectorized mul_mat_vec for quantized types - vectorized = vectorized && (src0->type < 2); - pipeline = ctx->mul_mat_vec_pipelines[src0->type][src1->type][vectorized]; - uint32_t batches = dst->ne[2] * dst->ne[3]; - uint32_t output_groups = CEIL_DIV(dst->ne[0], WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG); - uint32_t total_wg = output_groups * batches; - wg_x = total_wg % ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; - wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); - } else { - pipeline = ctx->mul_mat_pipelines[src0->type][src1->type][vectorized]; - uint32_t wg_m; - uint32_t wg_n; -#ifndef __EMSCRIPTEN__ - if (ctx->global_ctx->capabilities.supports_subgroup_matrix) { - // The total number of subgroups/workgroups needed per matrix. - uint32_t wg_m_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_M * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M * - ctx->global_ctx->capabilities.sg_mat_m; - wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); - uint32_t wg_n_sg_tile = WEBGPU_MUL_MAT_SUBGROUP_N * WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N * - ctx->global_ctx->capabilities.sg_mat_n; - wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); - } else { -#endif - uint32_t tile_m_s = WEBGPU_MUL_MAT_TILE_M * WEBGPU_MUL_MAT_WG_SIZE_M; - uint32_t tile_n_s = WEBGPU_MUL_MAT_TILE_N * WEBGPU_MUL_MAT_WG_SIZE_N; - wg_m = CEIL_DIV(dst->ne[0], tile_m_s); - wg_n = CEIL_DIV(dst->ne[1], tile_n_s); -#ifndef __EMSCRIPTEN__ - } -#endif + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .supports_subgroup_matrix = ctx->global_ctx->capabilities.supports_subgroup_matrix, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, + }; - wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; - } + // Get or create pipeline + webgpu_pipeline pipeline; + + if (use_fast && is_vec) { + pipeline = ctx->shader_lib->get_mul_mat_vec_pipeline(shader_lib_ctx); + } else if (use_fast) { + pipeline = ctx->shader_lib->get_mul_mat_fast_pipeline(shader_lib_ctx); + } else { + pipeline = ctx->shader_lib->get_mul_mat_legacy_pipeline(shader_lib_ctx); } + + // Build params + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) src0->ne[0], + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)), + (uint32_t) src0->ne[2], + (uint32_t) src0->ne[3], + (uint32_t) (src1->ne[2] / src0->ne[2]), + (uint32_t) (src1->ne[3] / src0->ne[3]) + }; + + // Build bind group entries + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }, + { .binding = 2, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }, + }; + + // Calculate workgroup dimensions + uint32_t wg_x = 1; + uint32_t wg_y = 1; + + if (use_fast && is_vec) { + auto decisions = static_cast(pipeline.context.get()); + + uint32_t batches = dst->ne[2] * dst->ne[3]; + uint32_t output_groups = CEIL_DIV(dst->ne[0], decisions->outputs_per_wg); + uint32_t total_wg = output_groups * batches; + // TODO: split large sizes into multiple batches to avoid way over-provisioning workgroups + wg_x = std::min(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + wg_y = CEIL_DIV(total_wg, ctx->global_ctx->capabilities.limits.maxComputeWorkgroupsPerDimension); + } else if (use_fast) { + auto decisions = static_cast(pipeline.context.get()); + + // Fast-path tiled/subgroup calculations + uint32_t wg_m, wg_n; + if (decisions->use_subgroup_matrix) { + uint32_t wg_m_sg_tile = + decisions->subgroup_m * decisions->subgroup_matrix_m * ctx->global_ctx->capabilities.sg_mat_m; + wg_m = CEIL_DIV(dst->ne[0], wg_m_sg_tile); + uint32_t wg_n_sg_tile = + decisions->subgroup_n * decisions->subgroup_matrix_n * ctx->global_ctx->capabilities.sg_mat_n; + wg_n = CEIL_DIV(dst->ne[1], wg_n_sg_tile); + } else { + uint32_t tile_m_s = decisions->tile_m * decisions->wg_size_m; + uint32_t tile_n_s = decisions->tile_n * decisions->wg_size_n; + wg_m = CEIL_DIV(dst->ne[0], tile_m_s); + wg_n = CEIL_DIV(dst->ne[1], tile_n_s); + } + wg_x = wg_m * wg_n * dst->ne[2] * dst->ne[3]; + } else { // legacy + auto decisions = static_cast(pipeline.context.get()); + uint32_t wg_size = decisions->wg_size; + wg_x = CEIL_DIV(dst->ne[0] * dst->ne[1] * dst->ne[2] * dst->ne[3], wg_size); + wg_y = 1; + } + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, wg_y); } @@ -1262,45 +1237,26 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); - bool kv_direct = (K->type == GGML_TYPE_F16) && (Q->ne[0] % ctx->global_ctx->capabilities.sg_mat_k == 0) && - (K->ne[1] % GGML_WEBGPU_KV_SEQ_PAD == 0); - - ggml_webgpu_flash_attn_pipeline_key key = { - .kv_type = K->type, - .head_dim_qk = (uint32_t) Q->ne[0], - .head_dim_v = (uint32_t) V->ne[0], - .kv_direct = kv_direct, - .has_mask = static_cast(has_mask), - .has_sinks = static_cast(has_sinks), - .uses_logit_softcap = logit_softcap != 0.0f, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = Q, + .src1 = K, + .src2 = V, + .src3 = mask, + .src4 = sinks, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, + .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, + .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, + .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, + .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size, }; - webgpu_pipeline pipeline; - auto it = ctx->flash_attn_pipelines.find(key); - if (it != ctx->flash_attn_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_flash_attn_shader_lib_context shader_lib_ctx = { - .key = key, - .sg_mat_m = ctx->global_ctx->capabilities.sg_mat_m, - .sg_mat_n = ctx->global_ctx->capabilities.sg_mat_n, - .sg_mat_k = ctx->global_ctx->capabilities.sg_mat_k, - .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .max_subgroup_size = ctx->global_ctx->capabilities.max_subgroup_size - }; + webgpu_pipeline pipeline = ctx->shader_lib->get_flash_attn_pipeline(shader_lib_ctx); - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_flash_attn_shader(ctx->p, wgsl_flash_attn, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->flash_attn_pipelines.emplace(key, pipeline); - } + auto * decisions = static_cast(pipeline.context.get()); - ggml_webgpu_flash_attn_shader_decisions decisions = - *static_cast(pipeline.context); - - uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile); + uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile); uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1309,30 +1265,18 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx, static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { bool is_unary = dst->op == GGML_OP_UNARY; bool inplace = ggml_webgpu_tensor_equal(src, dst) || (dst->op == GGML_OP_FILL); - int op = is_unary ? (int) ggml_get_unary_op(dst) : dst->op; - ggml_webgpu_unary_pipeline_key pipeline_key = { - .type = dst->type, .op = op, .is_unary = is_unary, .inplace = inplace - }; - ggml_webgpu_unary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, }; - webgpu_pipeline pipeline; - auto it = ctx->unary_pipelines.find(pipeline_key); - if (it != ctx->unary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_unary_shader(ctx->p, wgsl_unary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->unary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_unary_pipeline(shader_lib_ctx); - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1392,7 +1336,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1402,31 +1346,18 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * dst) { binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst); - ggml_webgpu_binary_pipeline_key pipeline_key = { - .type = dst->type, - .op = dst->op, - .inplace = flags.inplace, - .overlap = flags.overlap, - }; - ggml_webgpu_binary_shader_lib_context shader_lib_ctx = { - .key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src0, + .src1 = src1, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = flags.inplace, + .overlap = flags.overlap, }; - webgpu_pipeline pipeline; - auto it = ctx->binary_pipelines.find(pipeline_key); - if (it != ctx->binary_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - pipeline.context = processed.decisions; - ctx->binary_pipelines.emplace(pipeline_key, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_binary_pipeline(shader_lib_ctx); - ggml_webgpu_generic_shader_decisions decisions = - *static_cast(pipeline.context); + auto * decisions = static_cast(pipeline.context.get()); uint32_t ne = (uint32_t) ggml_nelements(dst); @@ -1471,7 +1402,7 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size); + uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -1651,8 +1582,20 @@ static webgpu_command ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, } static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - int inplace = ggml_webgpu_tensor_equal(src, dst); + bool inplace = ggml_webgpu_tensor_equal(src, dst); + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, + .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + .inplace = inplace, + }; + + webgpu_pipeline pipeline = ctx->shader_lib->get_scale_pipeline(shader_lib_ctx); + auto * decisions = static_cast(pipeline.context.get()); + + // params unchanged std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), @@ -1670,12 +1613,14 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, *(uint32_t *) &dst->op_params[1] // bias }; + // bindgroups unchanged std::vector entries = { { .binding = 0, .buffer = ggml_webgpu_tensor_buf(src), .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), @@ -1683,9 +1628,8 @@ static webgpu_command ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE); - return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, ctx->scale_pipelines[inplace], params, - entries, wg_x); + uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), decisions->wg_size); + return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_soft_max(webgpu_context & ctx, @@ -1778,74 +1722,40 @@ static webgpu_command ggml_webgpu_argmax(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = src->ne[0] % 4 == 0, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->argmax_pipelines.find(shader_lib_ctx.vec4); - if (it != ctx->argmax_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_argmax, shader_lib_ctx, "argmax"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->argmax_pipelines.emplace(shader_lib_ctx.vec4, pipeline); - } - uint32_t wg_x = ggml_nelements(dst); + webgpu_pipeline pipeline = ctx->shader_lib->get_argmax_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nelements(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool is_top_k = dst->op == GGML_OP_TOP_K; - // ascending order is 0, descending order is 1 - const int32_t order = is_top_k ? (int32_t) GGML_SORT_ORDER_DESC : (int32_t) ggml_get_op_params_i32(dst, 0); + bool is_top_k = dst->op == GGML_OP_TOP_K; - ggml_webgpu_argsort_shader_lib_context shader_lib_ctx = { + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, .wg_mem_limit_bytes = ctx->global_ctx->capabilities.limits.maxComputeWorkgroupStorageSize, - .order = order }; - webgpu_pipeline argsort_pipeline; - auto it = ctx->argsort_pipelines.find(order); - if (it != ctx->argsort_pipelines.end()) { - argsort_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_shader(ctx->p, wgsl_argsort, shader_lib_ctx); - argsort_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_pipeline.context = processed.decisions; - ctx->argsort_pipelines.emplace(order, argsort_pipeline); - } - ggml_webgpu_argsort_shader_decisions argsort_decisions = - *static_cast(argsort_pipeline.context); + webgpu_pipeline argsort_pipeline = ctx->shader_lib->get_argsort_pipeline(shader_lib_ctx); + auto * argsort_decisions = static_cast(argsort_pipeline.context.get()); - webgpu_pipeline argsort_merge_pipeline; - it = ctx->argsort_merge_pipelines.find(order); - if (it != ctx->argsort_merge_pipelines.end()) { - argsort_merge_pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_argsort_merge_shader(ctx->p, wgsl_argsort_merge, shader_lib_ctx); - argsort_merge_pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - argsort_merge_pipeline.context = processed.decisions; - ctx->argsort_merge_pipelines.emplace(order, argsort_merge_pipeline); - } + webgpu_pipeline argsort_merge_pipeline = ctx->shader_lib->get_argsort_merge_pipeline(shader_lib_ctx); const uint32_t src_ne0 = (uint32_t) src->ne[0]; const uint32_t nrows = (uint32_t) ggml_nrows(src); - const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size); + const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size); const uint32_t block_size = - is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size; + is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size; uint32_t out_ne0 = src_ne0; if (is_top_k) { if (npr > 1) { - const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size; + const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size; out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size); } else { out_ne0 = block_size; @@ -1994,22 +1904,15 @@ static webgpu_command ggml_webgpu_cumsum(webgpu_context & ctx, ggml_tensor * src .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, + .src1 = nullptr, + .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, }; - webgpu_pipeline pipeline; - auto it = ctx->cumsum_pipelines.find(1); - if (it != ctx->cumsum_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_cumsum, shader_lib_ctx, "cumsum"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->cumsum_pipelines.emplace(1, pipeline); - } - uint32_t wg_x = ggml_nrows(dst); + + webgpu_pipeline pipeline = ctx->shader_lib->get_cumsum_pipeline(shader_lib_ctx); + uint32_t wg_x = ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2035,22 +1938,12 @@ static webgpu_command ggml_webgpu_sum_rows(webgpu_context & ctx, ggml_tensor * s .size = ggml_webgpu_tensor_binding_size(ctx, dst) } }; - ggml_webgpu_generic_shader_lib_context shader_lib_ctx = { - .vec4 = false, - .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup, + ggml_webgpu_shader_lib_context shader_lib_ctx = { + .src0 = src, .dst = dst, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup }; - webgpu_pipeline pipeline; - auto it = ctx->sum_rows_pipelines.find(1); - if (it != ctx->sum_rows_pipelines.end()) { - pipeline = it->second; - } else { - ggml_webgpu_processed_shader processed = - ggml_webgpu_preprocess_generic_shader(ctx->p, wgsl_sum_rows, shader_lib_ctx, "sum_rows"); - pipeline = - ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str()); - ctx->sum_rows_pipelines.emplace(1, pipeline); - } + webgpu_pipeline pipeline = ctx->shader_lib->get_sum_rows_pipeline(shader_lib_ctx); + uint32_t wg_x = total_sum ? 1 : ggml_nrows(dst); return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x); } @@ -2198,7 +2091,10 @@ static ggml_backend_i ggml_backend_webgpu_i = { static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_webgpu_buffer_context * ctx = static_cast(buffer->context); - ctx->buffer.Destroy(); + if (ctx != nullptr && ctx->buffer != nullptr) { + ctx->buffer.Destroy(); + delete ctx; + } } // Returns the "fake" base pointer. @@ -2213,7 +2109,9 @@ static void ggml_backend_webgpu_buffer_memset_tensor(ggml_backend_buffer_t buffe size_t offset, size_t size) { if (size == 0) { - WEBGPU_LOG_DEBUG("ggml_backend_webgpu_buffer_memset_tensor: size is zero, nothing to do."); + WEBGPU_LOG_DEBUG( + "ggml_backend_webgpu_buffer_memset_tensor: size is zero, " + "nothing to do."); return; } @@ -2290,7 +2188,8 @@ static void ggml_backend_webgpu_buffer_get_tensor(ggml_backend_buffer_t buffer, size_t final_size = size; if (size % 4 != 0) { - // If size is not a multiple of 4, we need to round it up to the next multiple of 4 + // If size is not a multiple of 4, we need to round it up to the next + // multiple of 4 final_size = size + (4 - (size % 4)); } @@ -2344,7 +2243,8 @@ static ggml_backend_buffer_i ggml_backend_webgpu_buffer_interface = { /* .get_tensor = */ ggml_backend_webgpu_buffer_get_tensor, /* .cpy_tensor = */ NULL, // TODO: optional, implement this /* .clear = */ ggml_backend_webgpu_buffer_clear, - /* .reset = */ NULL, // TODO: optional, think it coordinates with .init_tensor + /* .reset = */ NULL, // TODO: optional, think it coordinates with + // .init_tensor }; /* End GGML Backend Buffer Interface */ @@ -2381,7 +2281,8 @@ static size_t ggml_backend_webgpu_buffer_type_get_alignment(ggml_backend_buffer_ return dev_ctx->webgpu_global_ctx->capabilities.limits.minStorageBufferOffsetAlignment; } -// maxBufferSize might be larger, but you can't bind more than maxStorageBufferBindingSize to a single binding. +// maxBufferSize might be larger, but you can't bind more than +// maxStorageBufferBindingSize to a single binding. static size_t ggml_backend_webgpu_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { ggml_backend_webgpu_device_context * dev_ctx = static_cast(buft->device->context); @@ -2467,14 +2368,6 @@ static ggml_guid_t ggml_backend_webgpu_guid(void) { return reinterpret_cast((void *) guid_str); } -// Workgroup size is a common constant -static std::vector ggml_webgpu_wg_size_entry(uint32_t wg_size) { - std::vector constants(1); - constants[0].key = "wg_size"; - constants[0].value = wg_size; - return constants; -} - static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { // we use the maximum workgroup size for the memset pipeline size_t max_threads = WEBGPU_MAX_WG_SIZE * ctx->capabilities.limits.maxComputeWorkgroupsPerDimension; @@ -2489,207 +2382,6 @@ static void ggml_webgpu_init_memset_pipeline(webgpu_global_context & ctx) { ctx->memset_pipelines[0] = ggml_webgpu_create_pipeline(ctx->device, wgsl_memset, "memset", constants); } -static void ggml_webgpu_init_mul_mat_pipeline(webgpu_context & webgpu_ctx) { - // Q4/Q5/Q8 classic quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_0_f32, "mul_mat_q4_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_1_f32, "mul_mat_q4_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_0_f32, "mul_mat_q5_0_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_1][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_1_f32, "mul_mat_q5_1_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q8_0][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q8_0_f32, "mul_mat_q8_0_f32"); - - // K-quantizations - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q2_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q2_k_f32, "mul_mat_q2_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q3_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q3_k_f32, "mul_mat_q3_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q4_k_f32, "mul_mat_q4_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q5_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q5_k_f32, "mul_mat_q5_k_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q6_K][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_q6_k_f32, "mul_mat_q6_k_f32"); - - // IQ quantizations (2-, 3-, 4-bit variants) - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xxs_f32, "mul_mat_iq2_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_xs_f32, "mul_mat_iq2_xs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ2_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq2_s_f32, "mul_mat_iq2_s_f32"); - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_XXS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_xxs_f32, "mul_mat_iq3_xxs_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ3_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq3_s_f32, "mul_mat_iq3_s_f32"); - - // 1-bit and 4-bit IQ variants - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_S][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_s_f32, "mul_mat_iq1_s_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ1_M][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq1_m_f32, "mul_mat_iq1_m_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_NL][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_nl_f32, "mul_mat_iq4_nl_f32"); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_IQ4_XS][GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_mat_iq4_xs_f32, "mul_mat_iq4_xs_f32"); - - std::string proc_mul_mat_f32_f32; - std::string proc_mul_mat_f32_f32_vec; - std::string proc_mul_mat_f16_f32; - std::string proc_mul_mat_f16_f32_vec; - std::string proc_mul_mat_f16_f16; - std::string proc_mul_mat_f16_f16_vec; - std::string proc_mul_mat_q4_0_f32; - std::string proc_mul_mat_q4_0_f32_vec; - - std::vector mul_mat_constants; -#ifndef __EMSCRIPTEN__ - if (webgpu_ctx->global_ctx->capabilities.supports_subgroup_matrix) { - std::map sg_matrix_repls; - sg_matrix_repls["WEBGPU_MAX_SUBGROUP_SIZE"] = - std::to_string(webgpu_ctx->global_ctx->capabilities.max_subgroup_size); - sg_matrix_repls["WEBGPU_TILE_K"] = std::to_string(WEBGPU_MUL_MAT_TILE_K); - sg_matrix_repls["WEBGPU_SUBGROUP_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_M); - sg_matrix_repls["WEBGPU_SUBGROUP_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_N); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_M"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_M); - sg_matrix_repls["WEBGPU_SUBGROUP_MATRIX_N"] = std::to_string(WEBGPU_MUL_MAT_SUBGROUP_MATRIX_N); - sg_matrix_repls["WEBGPU_SG_MAT_M_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_m); - sg_matrix_repls["WEBGPU_SG_MAT_N_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_n); - sg_matrix_repls["WEBGPU_SG_MAT_K_SIZE"] = std::to_string(webgpu_ctx->global_ctx->capabilities.sg_mat_k); - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32, sg_matrix_repls); - proc_mul_mat_f32_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f32_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32, sg_matrix_repls); - proc_mul_mat_f16_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f32_vec, sg_matrix_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16, sg_matrix_repls); - proc_mul_mat_f16_f16_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_f16_f16_vec, sg_matrix_repls); - proc_mul_mat_q4_0_f32 = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32, sg_matrix_repls); - proc_mul_mat_q4_0_f32_vec = - ggml_webgpu_process_shader_repls(wgsl_mul_mat_subgroup_matrix_q4_0_f32_vec, sg_matrix_repls); - } else { -#endif - mul_mat_constants.push_back({ .key = "TILE_K", .value = WEBGPU_MUL_MAT_TILE_K }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_M", .value = WEBGPU_MUL_MAT_WG_SIZE_M }); - mul_mat_constants.push_back({ .key = "WORKGROUP_SIZE_N", .value = WEBGPU_MUL_MAT_WG_SIZE_N }); - - std::map reg_repls; - reg_repls["WEBGPU_TILE_M"] = std::to_string(WEBGPU_MUL_MAT_TILE_M); - reg_repls["WEBGPU_TILE_N"] = std::to_string(WEBGPU_MUL_MAT_TILE_N); - - proc_mul_mat_f32_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32, reg_repls); - proc_mul_mat_f32_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f32_f32_vec, reg_repls); - proc_mul_mat_f16_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32, reg_repls); - proc_mul_mat_f16_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f32_vec, reg_repls); - proc_mul_mat_f16_f16 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16, reg_repls); - proc_mul_mat_f16_f16_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_f16_f16_vec, reg_repls); - proc_mul_mat_q4_0_f32 = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32, reg_repls); - proc_mul_mat_q4_0_f32_vec = ggml_webgpu_process_shader_repls(wgsl_mul_mat_reg_tile_q4_0_f32_vec, reg_repls); -#ifndef __EMSCRIPTEN__ - } -#endif - - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32.c_str(), "mul_mat_f32_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f32_f32_vec.c_str(), "mul_mat_f32_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32.c_str(), "mul_mat_f16_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f32_vec.c_str(), "mul_mat_f16_f32_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16.c_str(), "mul_mat_f16_f16", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_f16_f16_vec.c_str(), "mul_mat_f16_f16_vec", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32.c_str(), "mul_mat_q4_0_f32", mul_mat_constants); - webgpu_ctx->mul_mat_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, proc_mul_mat_q4_0_f32_vec.c_str(), "mul_mat_q4_0_f32_vec", mul_mat_constants); - - std::vector mul_mat_vec_constants(3); - mul_mat_vec_constants[0].key = "WORKGROUP_SIZE"; - mul_mat_vec_constants[0].value = WEBGPU_MUL_MAT_VEC_WG_SIZE; - mul_mat_vec_constants[1].key = "TILE_K"; - mul_mat_vec_constants[1].value = WEBGPU_MUL_MAT_VEC_TILE_K; - mul_mat_vec_constants[2].key = "OUTPUTS_PER_WG"; - mul_mat_vec_constants[2].value = WEBGPU_MUL_MAT_VEC_OUTPUTS_PER_WG; - - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32, "mul_mat_vec_f32_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F32][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f32_f32_vec, "mul_mat_vec_f32_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32, "mul_mat_vec_f16_f32", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f32_vec, "mul_mat_vec_f16_f32_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16, "mul_mat_vec_f16_f16", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_F16][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_f16_f16_vec, "mul_mat_vec_f16_f16_vec", mul_mat_vec_constants); - webgpu_ctx->mul_mat_vec_pipelines[GGML_TYPE_Q4_0][GGML_TYPE_F32][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_mul_mat_vec_q4_0_f32, "mul_mat_vec_q4_0_f32", mul_mat_vec_constants); -} - -static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f32, "get_rows_f32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F32][1] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_f32_vec, "get_rows_f32_vec", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_F16][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_f16, "get_rows_f16", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_I32][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_i32, "get_rows_i32", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_0, "get_rows_q4_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_1, "get_rows_q4_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_0, "get_rows_q5_0", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_1][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_1, "get_rows_q5_1", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q8_0][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q8_0, "get_rows_q8_0", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q2_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q2_k, "get_rows_q2_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q3_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q3_k, "get_rows_q3_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q4_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q4_k, "get_rows_q4_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q5_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q5_k, "get_rows_q5_k", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_Q6_K][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_q6_k, "get_rows_q6_k", constants); - - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xxs, "get_rows_iq2_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_xs, "get_rows_iq2_xs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ2_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq2_s, "get_rows_iq2_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_XXS][0] = ggml_webgpu_create_pipeline( - webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_xxs, "get_rows_iq3_xxs", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ3_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq3_s, "get_rows_iq3_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_S][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_s, "get_rows_iq1_s", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ1_M][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq1_m, "get_rows_iq1_m", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_NL][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_nl, "get_rows_iq4_nl", constants); - webgpu_ctx->get_rows_pipelines[GGML_TYPE_IQ4_XS][0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_get_rows_iq4_xs, "get_rows_iq4_xs", constants); -} - static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); @@ -2796,15 +2488,6 @@ static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { webgpu_ctx->global_ctx->device, wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); } -static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { - std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE); - - webgpu_ctx->scale_pipelines[0] = - ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32, "scale_f32", constants); - webgpu_ctx->scale_pipelines[1] = ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_scale_f32_inplace, - "scale_f32_inplace", constants); -} - static void ggml_webgpu_init_soft_max_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE); @@ -2926,12 +2609,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) { dev_desc.SetDeviceLostCallback( wgpu::CallbackMode::AllowSpontaneous, [](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) { + if (reason == wgpu::DeviceLostReason::Destroyed) { + return; + } GGML_UNUSED(device); - GGML_UNUSED(reason); - GGML_UNUSED(message); - //TODO: uncomment once proper free logic is in place - //GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), - //std::string(message).c_str()); + GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast(reason), + std::string(message).c_str()); }); dev_desc.SetUncapturedErrorCallback( [](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) { @@ -2995,6 +2678,7 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { ggml_backend_webgpu_device_context * dev_ctx = (ggml_backend_webgpu_device_context *) dev->context; webgpu_context webgpu_ctx = std::make_shared(); webgpu_ctx->global_ctx = dev_ctx->webgpu_global_ctx; + webgpu_ctx->shader_lib = std::make_unique(dev_ctx->webgpu_global_ctx->device); webgpu_ctx->param_buf_pool.init(webgpu_ctx->global_ctx->device, WEBGPU_NUM_PARAM_BUFS, WEBGPU_PARAMS_BUF_SIZE_BYTES, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::Uniform, wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::MapWrite); @@ -3003,13 +2687,10 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) { wgpu::BufferUsage::CopySrc | wgpu::BufferUsage::Storage, wgpu::BufferUsage::CopyDst | wgpu::BufferUsage::MapRead); - ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx); - ggml_webgpu_init_get_rows_pipeline(webgpu_ctx); ggml_webgpu_init_cpy_pipeline(webgpu_ctx); ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx); ggml_webgpu_init_rope_pipeline(webgpu_ctx); ggml_webgpu_init_glu_pipeline(webgpu_ctx); - ggml_webgpu_init_scale_pipeline(webgpu_ctx); ggml_webgpu_init_soft_max_pipeline(webgpu_ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers @@ -3051,11 +2732,11 @@ static ggml_backend_buffer_type_t ggml_backend_webgpu_device_get_buffer_type(ggm static struct ggml_backend_buffer_type ggml_backend_webgpu_buffer_type = { /* .iface = */ { /* .get_name = */ ggml_backend_webgpu_buffer_type_get_name, - /* .alloc_buffer = */ ggml_backend_webgpu_buffer_type_alloc_buffer, - /* .get_alignment = */ ggml_backend_webgpu_buffer_type_get_alignment, - /* .get_max_size = */ ggml_backend_webgpu_buffer_type_get_max_size, - /* .get_alloc_size = */ ggml_backend_webgpu_buffer_type_get_alloc_size, - /* .is_host = */ NULL, // defaults to false + /* .alloc_buffer = */ + ggml_backend_webgpu_buffer_type_alloc_buffer, /* .get_alignment = */ + ggml_backend_webgpu_buffer_type_get_alignment, /* .get_max_size = */ + ggml_backend_webgpu_buffer_type_get_max_size, /* .get_alloc_size = */ + ggml_backend_webgpu_buffer_type_get_alloc_size, /* .is_host = */ NULL, // defaults to false }, /* .device = */ dev, @@ -3365,10 +3046,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) { return ctx->device_count; } -// TODO: Does this need to be thread safe? Is it only called once? -// TODO: move most logic to device_init function so backend can be freed/initialized properly // Only one device is supported for now - static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) { GGML_ASSERT(index == 0); WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()"); diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl index 389c97bb51..9a5b18ebc0 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/common_decls.tmpl @@ -1,5 +1,4 @@ -#decl(BYTE_HELPERS) - +#ifdef BYTE_HELPERS fn get_byte(value: u32, index: u32) -> u32 { return (value >> (index * 8)) & 0xFF; } @@ -7,76 +6,74 @@ fn get_byte(value: u32, index: u32) -> u32 { fn get_byte_i32(value: u32, index: u32) -> i32 { return bitcast(((value >> (index * 8)) & 0xFF) << 24) >> 24; } +#endif -#enddecl(BYTE_HELPERS) - -#decl(Q4_0_T) +#ifdef Q4_0_T struct q4_0 { d: f16, qs: array }; -#enddecl(Q4_0_T) +#endif -#decl(Q4_1_T) +#ifdef Q4_1_T struct q4_1 { d: f16, m: f16, qs: array }; -#enddecl(Q4_1_T) +#endif -#decl(Q5_0_T) +#ifdef Q5_0_T struct q5_0 { d: f16, qh: array, qs: array }; -#enddecl(Q5_0_T) +#endif -#decl(Q5_1_T) +#ifdef Q5_1_T struct q5_1 { d: f16, m: f16, qh: u32, qs: array }; -#enddecl(Q5_1_T) +#endif -#decl(Q8_0_T) +#ifdef Q8_0_T struct q8_0 { d: f16, qs: array }; -#enddecl(Q8_0_T) +#endif -#decl(Q8_1_T) +#ifdef Q8_1_T struct q8_1 { d: f16, m: f16, qs: array }; -#enddecl(Q8_1_T) +#endif -#decl(Q2_K_T) -struct q2_k { +#ifdef Q2_K_T +struct q2_K { scales: array, qs: array, d: f16, dmin: f16 }; -#enddecl(Q2_K_T) +#endif -#decl(Q3_K_T) -struct q3_k { +#ifdef Q3_K_T +struct q3_K { hmask: array, qs: array, scales: array, d: f16 }; -#enddecl(Q3_K_T) - -#decl(Q45_K_SCALE_MIN) +#endif +#if defined(Q4_K_SCALE_MIN) || defined(Q5_K_SCALE_MIN) fn get_scale_min(is: u32, scales: array) -> vec2 { if (is < 4) { let sc_byte = get_byte(scales[is / 4], is % 4); @@ -91,69 +88,67 @@ fn get_scale_min(is: u32, scales: array) -> vec2 { return vec2(f32(sc), f32(m)); } } - -#enddecl(Q45_K_SCALE_MIN) - -#decl(Q4_K_T) -struct q4_k { +#endif +#ifdef Q4_K_T +struct q4_K { d: f16, dmin: f16, scales: array, qs: array }; -#enddecl(Q4_K_T) +#endif -#decl(Q5_K_T) -struct q5_k { +#ifdef Q5_K_T +struct q5_K { d: f16, dmin: f16, scales: array, qh: array, qs: array }; -#enddecl(Q5_K_T) +#endif -#decl(Q6_K_T) -struct q6_k { +#ifdef Q6_K_T +struct q6_K { ql: array, qh: array, scales: array, d: f16 }; -#enddecl(Q6_K_T) +#endif -#decl(IQ2_XXS_T) +#ifdef IQ2_XXS_T struct iq2_xxs { d: f16, qs: array }; -#enddecl(IQ2_XXS_T) +#endif -#decl(IQ2_XS_T) +#ifdef IQ2_XS_T struct iq2_xs { d: f16, qs: array, scales: array }; -#enddecl(IQ2_XS_T) +#endif -#decl(IQ2_S_T) +#ifdef IQ2_S_T struct iq2_s { d: f16, qs: array, qh: array, scales: array }; -#enddecl(IQ2_S_T) +#endif -#decl(IQ3_XSS_T) +#ifdef IQ3_XXS_T struct iq3_xxs { d: f16, qs: array }; -#enddecl(IQ3_XSS_T) +#endif -#decl(IQ3_S_T) +#ifdef IQ3_S_T struct iq3_s { d: f16, qs: array, @@ -161,41 +156,41 @@ struct iq3_s { signs: array, scales: array }; -#enddecl(IQ3_S_T) +#endif -#decl(IQ1_S_T) +#ifdef IQ1_S_T struct iq1_s { d: f16, qs: array, qh: array }; -#enddecl(IQ1_S_T) +#endif -#decl(IQ1_M_T) +#ifdef IQ1_M_T struct iq1_m { qs: array, qh: array, scales: array }; -#enddecl(IQ1_M_T) +#endif -#decl(IQ4_NL_T) +#ifdef IQ4_NL_T struct iq4_nl { d: f16, qs: array, }; -#enddecl(IQ4_NL_T) +#endif -#decl(IQ4_XS_T) +#ifdef IQ4_XS_T struct iq4_xs { d: f16, scales_h: f16, scales_l: u32, qs: array }; -#enddecl(IQ4_XS_T) +#endif -#decl(IQ23_TABLES) +#if defined(IQ2_XXS_TABLES) || defined(IQ2_XS_TABLES) || defined(IQ2_S_TABLES) || defined(IQ3_XXS_TABLES) || defined(IQ3_S_TABLES) const kmask_iq2xs : array = array( 0x08040201u, // 1, 2, 4, 8 0x80402010u // 16, 32, 64, 128 @@ -211,9 +206,9 @@ const ksigns_iq2xs: array = array( 0x63e2e160,0xe76665e4,0xeb6a69e8,0x6feeed6c, 0xf37271f0,0x77f6f574,0x7bfaf978,0xff7e7dfc ); -#enddecl(IQ23_TABLES) +#endif -#decl(IQ2_XXS_GRID) +#ifdef IQ2_XXS_GRID const iq2xxs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x082b0808, 0x08080808, @@ -280,9 +275,9 @@ const iq2xxs_grid = array( 0x0808082b, 0x2b2b0808, 0x19190808, 0x2b2b0808, 0x2b081919, 0x2b2b0808, 0x08082b19, 0x2b2b0819, 0x08080808, 0x2b2b082b, 0x08192b08, 0x2b2b1908, 0x19190808, 0x2b2b2b08, 0x08081908, 0x2b2b2b19 ); -#enddecl(IQ2_XXS_GRID) +#endif -#decl(IQ2_XS_GRID) +#ifdef IQ2_XS_GRID const iq2xs_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -413,9 +408,9 @@ const iq2xs_grid = array( 0x2b2b2b08, 0x2b2b2b08, 0x08081908, 0x2b2b2b19, 0x2b081908, 0x2b2b2b19, 0x2b08192b, 0x2b2b2b19, 0x082b2b08, 0x2b2b2b2b, 0x082b2b2b, 0x2b2b2b2b, 0x2b190819, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_XS_GRID) +#endif -#decl(IQ2_S_GRID) +#ifdef IQ2_S_GRID const iq2s_grid = array( 0x08080808, 0x08080808, 0x0808082b, 0x08080808, 0x08081919, 0x08080808, 0x08082b08, 0x08080808, 0x08082b2b, 0x08080808, 0x08190819, 0x08080808, 0x08191908, 0x08080808, 0x0819192b, 0x08080808, @@ -674,10 +669,9 @@ const iq2s_grid = array( 0x2b08192b, 0x2b2b2b19, 0x08082b08, 0x2b2b2b2b, 0x08082b2b, 0x2b2b2b2b, 0x082b0808, 0x2b2b2b2b, 0x082b082b, 0x2b2b2b2b, 0x082b2b08, 0x2b2b2b2b, 0x2b082b08, 0x2b2b2b2b, 0x2b2b2b2b, 0x2b2b2b2b ); -#enddecl(IQ2_S_GRID) - -#decl(IQ3_XSS_GRID) +#endif +#ifdef IQ3_XXS_GRID const iq3xxs_grid = array( 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, @@ -712,10 +706,9 @@ const iq3xxs_grid = array( 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04 ); -#enddecl(IQ3_XSS_GRID) - -#decl(IQ3_S_GRID) +#endif +#ifdef IQ3_S_GRID const iq3s_grid = array( 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, @@ -782,9 +775,9 @@ const iq3s_grid = array( 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101 ); -#enddecl(IQ3_S_GRID) +#endif -#decl(IQ1_GRID) +#if defined(IQ1_S_GRID) || defined(IQ1_M_GRID) const IQ1_DELTA: f32 = 0.125; @@ -919,12 +912,12 @@ const iq1_grid = array( 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 ); -#enddecl(IQ1_GRID) +#endif -#decl(IQ4_GRID) +#if defined(IQ4_NL_GRID) || defined(IQ4_XS_GRID) const kvalues_iq4nl = array( -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113 ); -#enddecl(IQ4_GRID) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d61df5bb9e..8b5cfe715e 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -56,12 +56,46 @@ def expand_includes(shader, input_dir): return include_pattern.sub(replacer, shader) -def write_shader(shader_name, shader_code, output_dir, outfile): +def chunk_shader(shader_code, max_chunk_len=60000): + """Split shader_code into safe raw-string sized chunks.""" + return [shader_code[i : i + max_chunk_len] for i in range(0, len(shader_code), max_chunk_len)] + + +def raw_delim(shader_code): + """Pick a raw-string delimiter that does not appear in the shader.""" + delim = "wgsl" + while f"){delim}\"" in shader_code: + delim += "_x" + return delim + + +def write_shader(shader_name, shader_code, output_dir, outfile, input_dir): + shader_code = expand_includes(shader_code, input_dir) + if output_dir: wgsl_filename = os.path.join(output_dir, f"{shader_name}.wgsl") with open(wgsl_filename, "w", encoding="utf-8") as f_out: f_out.write(shader_code) - outfile.write(f'const char* wgsl_{shader_name} = R"({shader_code})";\n\n') + + delim = raw_delim(shader_code) + chunks = chunk_shader(shader_code) + + if len(chunks) == 1: + outfile.write(f'const char* wgsl_{shader_name} = R"{delim}({shader_code}){delim}";\n\n') + else: + for idx, chunk in enumerate(chunks): + outfile.write(f'static const char wgsl_{shader_name}_part{idx}[] = R"{delim}({chunk}){delim}";\n\n') + outfile.write(f'static const std::string& wgsl_{shader_name}_str() {{\n') + outfile.write(' static const std::string s = []{\n') + outfile.write(' std::string tmp;\n') + outfile.write(f' tmp.reserve({len(shader_code)});\n') + for idx in range(len(chunks)): + outfile.write(f' tmp.append(wgsl_{shader_name}_part{idx});\n') + outfile.write(' return tmp;\n') + outfile.write(' }();\n') + outfile.write(' return s;\n') + outfile.write('}\n') + outfile.write(f'const char* wgsl_{shader_name} = wgsl_{shader_name}_str().c_str();\n\n') def generate_variants(fname, input_dir, output_dir, outfile): @@ -74,7 +108,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): try: variants = ast.literal_eval(extract_block(text, "VARIANTS")) except ValueError: - write_shader(shader_base_name, text, output_dir, outfile) + write_shader(shader_base_name, text, output_dir, outfile, input_dir) else: try: decls_map = parse_decls(extract_block(text, "DECLS")) @@ -123,7 +157,7 @@ def generate_variants(fname, input_dir, output_dir, outfile): output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name - write_shader(output_name, final_shader, output_dir, outfile) + write_shader(output_name, final_shader, output_dir, outfile, input_dir) def main(): @@ -137,7 +171,8 @@ def main(): os.makedirs(args.output_dir, exist_ok=True) with open(args.output_file, "w", encoding="utf-8") as out: - out.write("// Auto-generated shader embedding\n\n") + out.write("// Auto-generated shader embedding\n") + out.write("#include \n\n") for fname in sorted(os.listdir(args.input_dir)): if fname.endswith(".wgsl"): generate_variants(fname, args.input_dir, args.output_dir, out) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl similarity index 83% rename from ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl index f80ce1fc55..b10800e36d 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.wgsl @@ -1,222 +1,31 @@ -#define(VARIANTS) +enable f16; +#include "common_decls.tmpl" -[ - { - "SHADER_SUFFIX": "f32_vec", - "REPLS": { - "TYPE" : "vec4", - "DST_TYPE": "vec4", - "BLOCK_SIZE": 4 - }, - "DECLS": ["F32_VEC"] - }, - { - "REPLS": { - "TYPE" : "f32", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F32"] - }, - { - "REPLS": { - "TYPE" : "f16", - "DST_TYPE": "f32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["F16"] - }, - { - "REPLS": { - "TYPE" : "i32", - "DST_TYPE": "i32", - "BLOCK_SIZE": 1 - }, - "DECLS": ["I32"] - }, - { - "REPLS": { - "TYPE" : "q4_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "TYPE" : "q4_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "TYPE" : "q5_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "TYPE" : "q5_1", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "TYPE" : "q8_0", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "TYPE" : "q2_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "TYPE" : "q3_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "TYPE" : "q4_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "TYPE" : "q5_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "TYPE" : "q6_k", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "TYPE" : "iq2_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "TYPE" : "iq2_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "TYPE": "iq2_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "TYPE": "iq3_xxs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "TYPE": "iq3_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "TYPE": "iq1_s", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "TYPE": "iq1_m", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "TYPE": "iq4_nl", - "DST_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "TYPE": "iq4_xs", - "DST_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(F32_VEC) +#ifdef F32_VEC fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[(dst_base / 4) + offset] = src[(src_base / 4) + offset]; } -#enddecl(F32_VEC) +#endif -#decl(F32) +#ifdef F32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(F32) +#endif -#decl(F16) +#ifdef F16 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = f32(src[src_base + offset]); } -#enddecl(F16) +#endif -#decl(I32) +#ifdef I32 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst[dst_base + offset] = src[src_base + offset]; } -#enddecl(I32) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_0 = src[src_base + offset]; let d = f32(block_q4_0.d); @@ -232,9 +41,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q4_1 = src[src_base + offset]; let d = f32(block_q4_1.d); @@ -251,9 +60,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_0 = src[src_base + offset]; let d = f32(block_q5_0.d); @@ -272,10 +81,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(Q5_0) - -#decl(Q5_1) +#ifdef Q5_1 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q5_1 = src[src_base + offset]; let d = f32(block_q5_1.d); @@ -294,9 +102,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block_q8_0 = src[src_base + offset]; let d = f32(block_q8_0.d); @@ -310,9 +118,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q8_0) +#endif -#decl(Q2_K) +#ifdef Q2_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -340,9 +148,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q2_K) +#endif -#decl(Q3_K) +#ifdef Q3_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -398,9 +206,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q3_K) +#endif -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -425,9 +233,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q4_K) +#endif -#decl(Q5_K) +#ifdef Q5_K fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -455,9 +263,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(Q5_K) +#endif -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -511,10 +319,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { sc_b_idx += 8; } } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -536,9 +343,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XXS) +#endif -#decl(IQ2_XS) +#ifdef IQ2_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -568,9 +375,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ2_XS) +#endif -#decl(IQ2_S) +#ifdef IQ2_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -608,10 +415,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -638,9 +444,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_XSS) +#endif -#decl(IQ3_S) +#ifdef IQ3_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -683,9 +489,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -707,10 +513,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_S) - -#decl(IQ1_M) +#ifdef IQ1_M fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; @@ -751,10 +556,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { } } } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -770,9 +574,9 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i++; } } -#enddecl(IQ4_NL) +#endif -#decl(IQ4_XS) +#ifdef IQ4_XS fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { let block = src[src_base + offset]; let d = f32(block.d); @@ -791,24 +595,16 @@ fn copy_elements(src_base: u32, dst_base: u32, offset: u32) { dst_i += 16; } } -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif @group(0) @binding(0) -var src: array<{{TYPE}}>; +var src: array; @group(0) @binding(1) var idx: array; @group(0) @binding(2) -var dst: array<{{DST_TYPE}}>; +var dst: array; struct Params { offset_src: u32, // in elements @@ -842,8 +638,7 @@ struct Params { @group(0) @binding(3) var params: Params; -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.n_rows * params.ne2 * params.ne3) { return; @@ -866,9 +661,8 @@ fn main(@builtin(global_invocation_id) gid: vec3) { let i_src_row = params.offset_src + idx_val * params.stride_src1 + i_dst2 * params.stride_src2 + i_dst3 * params.stride_src3; let i_dst_row = params.offset_dst + i_dst1 * params.stride_dst1 + i_dst2 * params.stride_dst2 + i_dst3 * params.stride_dst3; - for (var i: u32 = 0; i < params.ne0/{{BLOCK_SIZE}}; i++) { + for (var i: u32 = 0; i < params.ne0/BLOCK_SIZE; i++) { copy_elements(i_src_row, i_dst_row, i); } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl similarity index 84% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl index 0f8e6e5ac3..6aba47317c 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat.wgsl @@ -1,195 +1,24 @@ -#define(VARIANTS) +enable f16; -[ - { - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "BLOCK_SIZE" : 1 - }, - "DECLS" : ["FLOAT"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_0_T", "Q4_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q4_1_T", "Q4_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_0_T", "Q5_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_1", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q5_1_T", "Q5_1"] - }, - { - "REPLS": { - "SRC0_TYPE": "q8_0", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32 - }, - "DECLS": ["BYTE_HELPERS", "Q8_0_T", "Q8_0"] - }, - { - "REPLS": { - "SRC0_TYPE": "q2_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q2_K_T", "Q2_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q3_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q3_K_T", "Q3_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q4_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q4_K_T", "Q4_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q5_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["Q45_K_SCALE_MIN", "BYTE_HELPERS", "Q5_K_T", "Q5_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "q6_k", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "Q6_K_T", "Q6_K"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XXS_GRID", "IQ2_XXS_T", "IQ2_XXS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_XS_GRID", "IQ2_XS_T", "IQ2_XS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq2_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ2_S_GRID", "IQ2_S_T", "IQ2_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_xxs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_XSS_GRID", "IQ3_XSS_T", "IQ3_XSS"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq3_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ23_TABLES", "IQ3_S_GRID", "IQ3_S_T", "IQ3_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_s", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_S_T", "IQ1_S"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq1_m", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256 - }, - "DECLS": ["BYTE_HELPERS", "IQ1_GRID", "IQ1_M_T", "IQ1_M"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_nl", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 32, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_NL_T", "IQ4_NL"] - }, - { - "REPLS": { - "SRC0_TYPE": "iq4_xs", - "SRC1_TYPE": "f32", - "BLOCK_SIZE": 256, - }, - "DECLS": ["BYTE_HELPERS", "IQ4_GRID", "IQ4_XS_T", "IQ4_XS"] - } -] +#include "common_decls.tmpl" -#end(VARIANTS) +#ifdef FLOAT +const BLOCK_SIZE = 1u; -#define(DECLS) +#elif defined(Q4_0) || defined(Q4_1) || defined(Q5_0) || defined(Q5_1) || defined(Q8_0) || defined(Q8_1) || defined(IQ4_NL) +const BLOCK_SIZE = 32u; -#decl(FLOAT) +#elif defined(Q2_K) || defined(Q3_K) || defined(Q4_K) || defined(Q5_K) || defined(Q6_K) || defined(IQ2_XXS) || defined(IQ2_XS) || defined(IQ2_S) || defined(IQ3_XXS) || defined(IQ3_S) || defined(IQ1_S) || defined(IQ1_M) || defined(IQ4_XS) +const BLOCK_SIZE = 256u; +#endif + +#ifdef FLOAT fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { return f32(src0[src0_idx_base + offset]) * f32(src1[src1_idx_base + offset]); } -#enddecl(FLOAT) +#endif -#decl(Q4_0) +#ifdef Q4_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_0 = src0[src0_idx_base + offset]; let d = f32(block_q4_0.d); @@ -207,9 +36,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_0) +#endif -#decl(Q4_1) +#ifdef Q4_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q4_1 = src0[src0_idx_base + offset]; let d = f32(block_q4_1.d); @@ -228,9 +57,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q4_1) +#endif -#decl(Q5_0) +#ifdef Q5_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_0 = src0[src0_idx_base + offset]; let d = f32(block_q5_0.d); @@ -251,9 +80,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_0) +#endif -#decl(Q5_1) +#ifdef Q5_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q5_1 = src0[src0_idx_base + offset]; let d = f32(block_q5_1.d); @@ -274,9 +103,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q5_1) +#endif -#decl(Q8_0) +#ifdef Q8_0 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_0 = src0[src0_idx_base + offset]; let d = f32(block_q8_0.d); @@ -292,9 +121,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_0) +#endif -#decl(Q8_1) +#ifdef Q8_1 fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block_q8_1 = src0[src0_idx_base + offset]; let d = f32(block_q8_1.d); @@ -311,9 +140,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(Q8_1) +#endif -#decl(Q2_K) +#ifdef Q2_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -344,10 +173,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q2_K) - -#decl(Q3_K) +#ifdef Q3_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -406,10 +234,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q3_K) - -#decl(Q4_K) +#ifdef Q4_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -436,10 +263,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q4_K) - -#decl(Q5_K) +#ifdef Q5_K // 8 blocks of 32 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -470,10 +296,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q5_K) - -#decl(Q6_K) +#ifdef Q6_K // 16 blocks of 16 elements each fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -529,10 +354,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(Q6_K) - -#decl(IQ2_XXS) +#ifdef IQ2_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -556,10 +380,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XXS) - -#decl(IQ2_XS) +#ifdef IQ2_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -591,10 +414,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ2_XS) - -#decl(IQ2_S) +#ifdef IQ2_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -634,11 +456,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif - -#enddecl(IQ2_S) - -#decl(IQ3_XSS) +#ifdef IQ3_XXS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -667,10 +487,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ3_XSS) - -#decl(IQ3_S) +#ifdef IQ3_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -715,9 +534,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } -#enddecl(IQ3_S) +#endif -#decl(IQ1_S) +#ifdef IQ1_S fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -741,10 +560,10 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_S) -#decl(IQ1_M) +#ifdef IQ1_M fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; @@ -787,10 +606,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ1_M) - -#decl(IQ4_NL) +#ifdef IQ4_NL fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -808,10 +626,9 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } +#endif -#enddecl(IQ4_NL) - -#decl(IQ4_XS) +#ifdef IQ4_XS fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { let block = src0[src0_idx_base + offset]; let d = f32(block.d); @@ -832,16 +649,7 @@ fn multiply_add(src0_idx_base: u32, src1_idx_base: u32, offset: u32) -> f32 { } return sum; } - -#enddecl(IQ4_XS) - -#end(DECLS) - -#define(SHADER) - -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, // in elements/blocks @@ -864,8 +672,8 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) @group(0) @binding(2) var dst: array; // M rows, N columns @group(0) @binding(3) var params: MulMatParams; @@ -898,10 +706,8 @@ fn main(@builtin(global_invocation_id) global_id: vec3) { let src1_idx_base = params.offset_src1 + src13_idx * params.stride_13 + src12_idx * params.stride_12 + row * params.stride_11; var sum = 0.0; - for (var i: u32 = 0u; i < params.k/{{BLOCK_SIZE}}; i = i + 1u) { + for (var i: u32 = 0u; i < params.k/BLOCK_SIZE; i = i + 1u) { sum += multiply_add(src0_idx_base, src1_idx_base, i); } dst[params.offset_dst + dst3_idx * dst3_stride + dst2_idx * dst2_stride + row * params.m + col] = sum; } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl index 109ff8d615..5c1074ebc1 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_decls.tmpl @@ -1,58 +1,65 @@ -#decl(SHMEM_VEC) +#ifdef VEC +#define VEC_SIZE 4 +#define SHMEM_TYPE vec4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + fn store_shmem(val: vec4, idx: u32) { shmem[idx] = val.x; shmem[idx + 1] = val.y; shmem[idx + 2] = val.z; shmem[idx + 3] = val.w; } -#enddecl(SHMEM_VEC) +#endif + +#ifdef SCALAR +#define VEC_SIZE 1 +#define SHMEM_TYPE f16 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE -#decl(SHMEM_SCALAR) fn store_shmem(val: f16, idx: u32) { shmem[idx] = val; } -#enddecl(SHMEM_SCALAR) - -#decl(INIT_SRC0_SHMEM_FLOAT) +#endif +#ifdef INIT_SRC0_SHMEM_FLOAT fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC0_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_m = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_m = offset_m + tile_m; let global_k = k_outer + tile_k; let src0_idx = batch_offset + global_m * params.stride_01 + global_k; let src0_val = select( // taking a slight performance hit to avoid oob - {{SRC0_TYPE}}(0.0), - src0[src0_idx/{{VEC_SIZE}}], + SRC0_TYPE(0.0), + src0[src0_idx/VEC_SIZE], global_m < params.m && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src0_val), elem_idx); + store_shmem(SHMEM_TYPE(src0_val), elem_idx); } } +#endif -#enddecl(INIT_SRC0_SHMEM_FLOAT) - -#decl(INIT_SRC1_SHMEM) - +#ifdef INIT_SRC1_SHMEM_FLOAT fn init_shmem_src1(thread_id: u32, batch_offset: u32, offset_n: u32, k_outer: u32) { - for (var elem_idx = thread_id * {{VEC_SIZE}}; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var elem_idx = thread_id * VEC_SIZE; elem_idx < TILE_SRC1_SHMEM; elem_idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let tile_n = elem_idx / TILE_K; let tile_k = elem_idx % TILE_K; let global_n = offset_n + tile_n; let global_k = k_outer + tile_k; let src1_idx = batch_offset + global_n * params.stride_11 + global_k; let src1_val = select( - {{SRC1_TYPE}}(0.0), - src1[src1_idx/{{VEC_SIZE}}], + SRC1_TYPE(0.0), + src1[src1_idx/VEC_SIZE], global_n < params.n && global_k < params.k); - store_shmem({{SHMEM_TYPE}}(src1_val), TILE_SRC0_SHMEM + elem_idx); + store_shmem(SHMEM_TYPE(src1_val), TILE_SRC0_SHMEM + elem_idx); } } +#endif -#enddecl(INIT_SRC1_SHMEM) - -#decl(INIT_SRC0_SHMEM_Q4_0) - +#ifdef INIT_SRC0_SHMEM_Q4_0 const BLOCK_SIZE = 32u; // the number of blocks per k-tile. Note that this currently only works if TILE_K is a multiple of BLOCK_SIZE, which may need to be rethought for larger quantized types. override BLOCKS_K = TILE_K/BLOCK_SIZE; @@ -93,5 +100,4 @@ fn init_shmem_src0(thread_id: u32, batch_offset: u32, offset_m: u32, k_outer: u3 } } } - -#enddecl(INIT_SRC0_SHMEM_Q4_0) +#endif diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl similarity index 55% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl index 6b1dd26cd9..771e5cd1ee 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_reg_tile.wgsl @@ -1,115 +1,19 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] +enable f16; -#end(VARIANTS) +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" -#define(DECLS) - -#decl(VEC) +#ifdef VEC fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> vec4 { return vec4(f32(acc[tm][tn]), f32(acc[tm + 1][tn]), f32(acc[tm + 2][tn]), f32(acc[tm + 3][tn])); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_val(acc: array, TILE_M>, tn: u32, tm: u32) -> f32 { return f32(acc[tm][tn]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -enable f16; +#endif struct MulMatParams { offset_src0: u32, @@ -130,14 +34,12 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS - fn get_local_n(thread_id: u32) -> u32 { return thread_id / WORKGROUP_SIZE_M; } @@ -145,18 +47,9 @@ fn get_local_m(thread_id: u32) -> u32 { return thread_id % WORKGROUP_SIZE_M; } -// TILE_M must be multiple of 4 for vec4 loads -const TILE_M = {{WEBGPU_TILE_M}}u; -const TILE_N = {{WEBGPU_TILE_N}}u; - -override WORKGROUP_SIZE_M: u32; -override WORKGROUP_SIZE_N: u32; -override TILE_K: u32; - -override TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; -override TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; -override TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; - +const TOTAL_WORKGROUP_SIZE = WORKGROUP_SIZE_M * WORKGROUP_SIZE_N; +const TILE_SRC0_SHMEM = TILE_K * WORKGROUP_SIZE_M * TILE_M; +const TILE_SRC1_SHMEM = TILE_K * WORKGROUP_SIZE_N * TILE_N; var shmem: array; @compute @workgroup_size(TOTAL_WORKGROUP_SIZE) @@ -233,15 +126,13 @@ fn main(@builtin(workgroup_id) wg_id: vec3, for (var tn = 0u; tn < TILE_N; tn++) { let global_col = output_col_base + tn; if (global_col < params.n) { - for (var tm = 0u; tm < TILE_M; tm += {{VEC_SIZE}}) { + for (var tm = 0u; tm < TILE_M; tm += VEC_SIZE) { let global_row = output_row_base + tm; if (global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - dst[dst_idx/{{VEC_SIZE}}] = store_val(acc, tn, tm); + dst[dst_idx/VEC_SIZE] = store_val(acc, tn, tm); } } } } } - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl similarity index 66% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl index 47c8ce36ab..64529e03cd 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_subgroup_matrix.wgsl @@ -1,100 +1,12 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_FLOAT", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32_vec", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "vec4", - "DST_TYPE" : "vec4", - "SHMEM_TYPE" : "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["BYTE_HELPERS", "VEC", "SHMEM_VEC", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE" : "f32", - "SHMEM_TYPE" : "f16", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "SHMEM_SCALAR", "INIT_SRC0_SHMEM_Q4_0", "INIT_SRC1_SHMEM"] - } -] +diagnostic(off, chromium.subgroup_matrix_uniformity); +enable f16; +enable subgroups; +enable chromium_experimental_subgroup_matrix; -#end(VARIANTS) +#include "common_decls.tmpl" +#include "mul_mat_decls.tmpl" -#define(DECLS) - -#decl(VEC) +#ifdef VEC fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = vec4( f32(shmem[shmem_idx]), @@ -103,21 +15,13 @@ fn store_dst(shmem_idx: u32, dst_idx: u32) { f32(shmem[shmem_idx + 3]) ); } -#enddecl(VEC) +#endif -#decl(SCALAR) +#ifdef SCALAR fn store_dst(shmem_idx: u32, dst_idx: u32) { dst[dst_idx] = f32(shmem[shmem_idx]); } -#enddecl(SCALAR) - -#end(DECLS) - -#define(SHADER) -diagnostic(off, chromium.subgroup_matrix_uniformity); -enable f16; -enable subgroups; -enable chromium_experimental_subgroup_matrix; +#endif struct MulMatParams { offset_src0: u32, @@ -138,36 +42,19 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // M rows, K columns -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // K rows, N columns (transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // M rows, N columns (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -DECLS - -// Note: These are string interpolated at build time, cannot use override constants due to limitations in -// current Dawn version type definitions/matrix load requirements for constant memory sizes. -const SUBGROUP_M = {{WEBGPU_SUBGROUP_M}}u; -const SUBGROUP_N = {{WEBGPU_SUBGROUP_N}}u; -// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the -// runtime subgroup size is smaller. -const MAX_SUBGROUP_SIZE = {{WEBGPU_MAX_SUBGROUP_SIZE}}u; - -const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; - -const SUBGROUP_MATRIX_M_SIZE = {{WEBGPU_SG_MAT_M_SIZE}}u; -const SUBGROUP_MATRIX_N_SIZE = {{WEBGPU_SG_MAT_N_SIZE}}u; -const SUBGROUP_MATRIX_K_SIZE = {{WEBGPU_SG_MAT_K_SIZE}}u; - -const SUBGROUP_MATRIX_M = {{WEBGPU_SUBGROUP_MATRIX_M}}u; -const SUBGROUP_MATRIX_N = {{WEBGPU_SUBGROUP_MATRIX_N}}u; - -const TILE_K = {{WEBGPU_TILE_K}}u; - const WG_M_SG_TILE_SIZE = SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const WG_N_SG_TILE_SIZE = SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; +// For portability we assume the max subgroup size, meaning some subgroups will be masked out if the +// runtime subgroup size is smaller. +const EXPECTED_SUBGROUPS = SUBGROUP_M * SUBGROUP_N; const TOTAL_WORKGROUP_SIZE = SUBGROUP_M * SUBGROUP_N * MAX_SUBGROUP_SIZE; const TILE_SRC0_SHMEM = TILE_K * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; const TILE_SRC1_SHMEM = TILE_K * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; @@ -285,7 +172,7 @@ fn main(@builtin(workgroup_id) wg_id: vec3, let tile_dst_row_base = wg_m * SUBGROUP_M * SUBGROUP_MATRIX_M * SUBGROUP_MATRIX_M_SIZE; let tile_dst_col_base = wg_n * SUBGROUP_N * SUBGROUP_MATRIX_N * SUBGROUP_MATRIX_N_SIZE; - for (var idx = thread_id * {{VEC_SIZE}}; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * {{VEC_SIZE}}) { + for (var idx = thread_id * VEC_SIZE; idx < total_tile_elems; idx += TOTAL_WORKGROUP_SIZE * VEC_SIZE) { let local_row = idx % WG_TILE_STRIDE; let local_col = idx / WG_TILE_STRIDE; @@ -294,9 +181,8 @@ fn main(@builtin(workgroup_id) wg_id: vec3, if (global_col < params.n && global_row < params.m) { let dst_idx = dst_batch_offset + global_col * params.m + global_row; - store_dst(idx, dst_idx/{{VEC_SIZE}}); + store_dst(idx, dst_idx/VEC_SIZE); } } } -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl similarity index 61% rename from ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl index ffbb640328..f9ea95e07b 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/mul_mat_vec.wgsl @@ -1,84 +1,17 @@ -#define(VARIANTS) -[ - { - "SHADER_SUFFIX": "f32_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f32_f32", - "REPLS": { - "SRC0_TYPE" : "f32", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16_vec", - "REPLS": { - "SRC0_TYPE" : "vec4", - "SRC1_TYPE" : "vec4", - "DST_TYPE": "vec4", - "VEC_SIZE" : 4, - }, - "DECLS": ["VEC", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "f16_f16", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f16", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["SCALAR", "MUL_ACC_FLOAT"] - }, - { - "SHADER_SUFFIX": "q4_0_f32", - "REPLS": { - "SRC0_TYPE" : "f16", - "SRC1_TYPE" : "f32", - "DST_TYPE": "f32", - "VEC_SIZE" : 1, - }, - "DECLS": ["BYTE_HELPERS", "SCALAR", "MUL_ACC_Q4_0"] - } -] -#end(VARIANTS) +enable f16; -#define(DECLS) +#include "common_decls.tmpl" -#decl(VEC) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { - return f32(dot({{SRC1_TYPE}}(src0_val), src1_val)); +#ifdef VEC + +#define VEC_SIZE 4 +#define DST_TYPE vec4 +#define SRC0_TYPE vec4 +#define SRC1_TYPE vec4 + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { + return f32(dot(SRC1_TYPE(src0_val), src1_val)); } fn store_val(group_base: u32) -> vec4 { @@ -87,33 +20,37 @@ fn store_val(group_base: u32) -> vec4 { partial_sums[group_base + THREADS_PER_OUTPUT * 2], partial_sums[group_base + THREADS_PER_OUTPUT * 3]); } -#enddecl(VEC) +#endif -#decl(SCALAR) -fn inner_dot(src0_val: {{SRC0_TYPE}}, src1_val: {{SRC1_TYPE}}) -> f32 { +#ifdef SCALAR + +#define VEC_SIZE 1 +#define DST_TYPE f32 +#define SRC0_TYPE SRC0_INNER_TYPE +#define SRC1_TYPE SRC1_INNER_TYPE + +fn inner_dot(src0_val: SRC0_TYPE, src1_val: SRC1_TYPE) -> f32 { return f32(src0_val) * f32(src1_val); } fn store_val(group_base: u32) -> f32 { return partial_sums[group_base]; } -#enddecl(SCALAR) - -#decl(MUL_ACC_FLOAT) +#endif +#ifdef MUL_ACC_FLOAT fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { var local_sum = 0.0; - for (var i = tig * {{VEC_SIZE}}; i < tile_size; i += THREADS_PER_OUTPUT * {{VEC_SIZE}}) { - let a = src0[(idx_base + k_outer + i) / {{VEC_SIZE}}]; - let b = shared_vector[i / {{VEC_SIZE}}]; + for (var i = tig * VEC_SIZE; i < tile_size; i += THREADS_PER_OUTPUT * VEC_SIZE) { + let a = src0[(idx_base + k_outer + i) / VEC_SIZE]; + let b = shared_vector[i / VEC_SIZE]; local_sum += inner_dot(a, b); } return local_sum; } +#endif -#enddecl(MUL_ACC_FLOAT) - -#decl(MUL_ACC_Q4_0) +#ifdef MUL_ACC_Q4_0 const BLOCK_SIZE = 32; const NQ = 16u; // number of weights per thread @@ -145,15 +82,7 @@ fn mul_acc(tig:u32, tile_size: u32, idx_base: u32, k_outer: u32) -> f32 { } return local_sum; } - -#enddecl(MUL_ACC_Q4_0) - -#end(DECLS) - -#define(SHADER) -enable f16; - -DECLS +#endif struct MulMatParams { offset_src0: u32, @@ -174,22 +103,20 @@ struct MulMatParams { broadcast3: u32 }; -@group(0) @binding(0) var src0: array<{{SRC0_TYPE}}>; // Matrix (M x K) -@group(0) @binding(1) var src1: array<{{SRC1_TYPE}}>; // Vector (K x 1, transposed) -@group(0) @binding(2) var dst: array<{{DST_TYPE}}>; // Result vector (transposed) +// SRC0_TYPE and SRC1_TYPE are defined in mul_mat_decls, which is included +@group(0) @binding(0) var src0: array; // M rows, K columns +@group(0) @binding(1) var src1: array; // K rows, N columns (transposed) +@group(0) @binding(2) var dst: array; // M rows, N columns (transposed) @group(0) @binding(3) var params: MulMatParams; -override WORKGROUP_SIZE: u32; -override TILE_K: u32; -override OUTPUTS_PER_WG: u32; -override THREADS_PER_OUTPUT = WORKGROUP_SIZE / OUTPUTS_PER_WG; +const THREADS_PER_OUTPUT = WG_SIZE / OUTPUTS_PER_WG; // Shared memory for collaborative loading and reduction -var shared_vector: array<{{SRC1_TYPE}}, TILE_K/{{VEC_SIZE}}>; // Cache vector tile -var partial_sums: array; // For reduction +var shared_vector: array; // Cache vector tile +var partial_sums: array; // For reduction -@compute @workgroup_size(WORKGROUP_SIZE) +@compute @workgroup_size(WG_SIZE) fn main( @builtin(local_invocation_id) local_id: vec3, @builtin(workgroup_id) wg_id: vec3, @@ -232,8 +159,8 @@ fn main( let tile_size = min(TILE_K, params.k - k_tile); // Cooperatively load vector tile into shared memory (all threads) - for (var i = thread_id * {{VEC_SIZE}}; i < tile_size; i += WORKGROUP_SIZE * {{VEC_SIZE}}) { - shared_vector[i / {{VEC_SIZE}}] = src1[(src1_idx_base + k_tile + i) / {{VEC_SIZE}}]; + for (var i = thread_id * VEC_SIZE; i < tile_size; i += WG_SIZE * VEC_SIZE) { + shared_vector[i / VEC_SIZE] = src1[(src1_idx_base + k_tile + i) / VEC_SIZE]; } workgroupBarrier(); @@ -250,7 +177,7 @@ fn main( workgroupBarrier(); let group_base = thread_group * THREADS_PER_OUTPUT; let thread_base = group_base + thread_in_group; - var offset = THREADS_PER_OUTPUT / 2; + var offset: u32 = THREADS_PER_OUTPUT / 2; while (offset > 0) { if (thread_in_group < offset) { partial_sums[thread_base] += partial_sums[thread_base + offset]; @@ -260,8 +187,8 @@ fn main( } // Store back to global memory - if (output_row < params.m && thread_group % {{VEC_SIZE}} == 0 && thread_in_group == 0) { - dst[dst_idx / {{VEC_SIZE}}] = store_val(group_base); + if (output_row < params.m && thread_group % VEC_SIZE == 0 && thread_in_group == 0) { + dst[dst_idx / VEC_SIZE] = store_val(group_base); } } -#end(SHADER) + diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl similarity index 78% rename from ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl rename to ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl index 040e80dfea..3b70a876d7 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.wgsl @@ -1,21 +1,11 @@ -#define(VARIANTS) +#ifdef INPLACE +@group(0) @binding(1) +var params: Params; -[ - { - "SHADER_NAME": "scale_f32", - "DECLS": ["NOT_INPLACE"] - }, - { - "SHADER_NAME": "scale_f32_inplace", - "DECLS": ["INPLACE"] - } -] - -#end(VARIANTS) - -#define(DECLS) - -#decl(NOT_INPLACE) +fn store_scale(val: f32, offset: u32) { + src[offset] = val; +} +#else @group(0) @binding(1) var dst: array; @@ -25,20 +15,7 @@ var params: Params; fn store_scale(val: f32, offset: u32) { dst[offset] = val; } -#enddecl(NOT_INPLACE) - -#decl(INPLACE) -@group(0) @binding(1) -var params: Params; - -fn store_scale(val: f32, offset: u32) { - src[offset] = val; -} -#enddecl(INPLACE) - -#end(DECLS) - -#define(SHADER) +#endif struct Params { offset_src: u32, @@ -65,10 +42,7 @@ struct Params { @group(0) @binding(0) var src: array; -DECLS - -override wg_size: u32; -@compute @workgroup_size(wg_size) +@compute @workgroup_size(WG_SIZE) fn main(@builtin(global_invocation_id) gid: vec3) { if (gid.x >= params.ne) { return; @@ -87,4 +61,3 @@ fn main(@builtin(global_invocation_id) gid: vec3) { store_scale(src[i_src] * params.scale + params.bias, i_dst); } -#end(SHADER) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index 500cb6b72f..ed819eaa4c 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -1496,6 +1496,10 @@ bool ggml_are_same_stride(const struct ggml_tensor * t0, const struct ggml_tenso (t0->nb[3] == t1->nb[3]); } +bool ggml_is_view(const struct ggml_tensor * t) { + return ggml_impl_is_view(t); +} + // check if t1 can be represented as a repetition of t0 bool ggml_can_repeat(const struct ggml_tensor * t0, const struct ggml_tensor * t1) { static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function"); @@ -5749,7 +5753,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 3af4fffe95..e90826dd1b 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -142,6 +142,7 @@ class Keys: EMBEDDING_SCALE = "{arch}.embedding_scale" TOKEN_SHIFT_COUNT = "{arch}.token_shift_count" INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step" + FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval" ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale" ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx" ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs" @@ -180,6 +181,11 @@ class Keys: SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" + class Indexer: + HEAD_COUNT = "{arch}.attention.indexer.head_count" + KEY_LENGTH = "{arch}.attention.indexer.key_length" + TOP_K = "{arch}.attention.indexer.top_k" + class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" @@ -384,6 +390,8 @@ class MODEL_ARCH(IntEnum): QWEN3NEXT = auto() QWEN3VL = auto() QWEN3VLMOE = auto() + QWEN35 = auto() + QWEN35MOE = auto() PHI2 = auto() PHI3 = auto() PHIMOE = auto() @@ -422,10 +430,12 @@ class MODEL_ARCH(IntEnum): CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() + GLM_DSA = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() JAIS = auto() + JAIS2 = auto() NEMOTRON = auto() NEMOTRON_H = auto() NEMOTRON_H_MOE = auto() @@ -557,13 +567,14 @@ class MODEL_TENSOR(IntEnum): SSM_D = auto() SSM_NORM = auto() SSM_OUT = auto() + SSM_ALPHA = auto() # qwen3.5 SSM_BETA_ALPHA = auto() # qwen3next SSM_CONV1D_Q = auto() # Kimi Linear SSM_CONV1D_K = auto() # Kimi Linear SSM_CONV1D_V = auto() # Kimi Linear SSM_F_A = auto() # Kimi Linear SSM_F_B = auto() # Kimi Linear - SSM_BETA = auto() # Kimi Linear + SSM_BETA = auto() # Kimi Linear qwen3.5 SSM_G_A = auto() # Kimi Linear SSM_G_B = auto() # Kimi Linear TIME_MIX_W0 = auto() @@ -642,6 +653,7 @@ class MODEL_TENSOR(IntEnum): ENC_OUTPUT_NORM = auto() CLS = auto() # classifier CLS_OUT = auto() # classifier output projection + CLS_NORM = auto() CONV1D = auto() CONVNEXT_DW = auto() CONVNEXT_NORM = auto() @@ -666,6 +678,10 @@ class MODEL_TENSOR(IntEnum): VISEXP_GATE = auto() VISEXP_DOWN = auto() VISEXP_UP = auto() + INDEXER_K_NORM = auto() + INDEXER_PROJ = auto() + INDEXER_ATTN_K = auto() + INDEXER_ATTN_Q_B = auto() # vision V_MMPROJ = auto() V_MMPROJ_FC = auto() @@ -814,6 +830,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.QWEN3NEXT: "qwen3next", MODEL_ARCH.QWEN3VL: "qwen3vl", MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe", + MODEL_ARCH.QWEN35: "qwen35", + MODEL_ARCH.QWEN35MOE: "qwen35moe", MODEL_ARCH.PHI2: "phi2", MODEL_ARCH.PHI3: "phi3", MODEL_ARCH.PHIMOE: "phimoe", @@ -852,10 +870,12 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", + MODEL_ARCH.GLM_DSA: "glm-dsa", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", MODEL_ARCH.JAIS: "jais", + MODEL_ARCH.JAIS2: "jais2", MODEL_ARCH.NEMOTRON: "nemotron", MODEL_ARCH.NEMOTRON_H: "nemotron_h", MODEL_ARCH.NEMOTRON_H_MOE: "nemotron_h_moe", @@ -985,13 +1005,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d", MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm", MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out", + MODEL_TENSOR.SSM_ALPHA: "blk.{bid}.ssm_alpha", # qwen3.5 MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba", MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear - MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear + MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear qwen3.5 MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0", @@ -1070,6 +1091,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.ENC_OUTPUT_NORM: "enc.output_norm", MODEL_TENSOR.CLS: "cls", MODEL_TENSOR.CLS_OUT: "cls.output", + MODEL_TENSOR.CLS_NORM: "cls.norm", MODEL_TENSOR.CONV1D: "conv1d", MODEL_TENSOR.CONVNEXT_DW: "convnext.{bid}.dw", MODEL_TENSOR.CONVNEXT_NORM: "convnext.{bid}.norm", @@ -1094,6 +1116,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate", MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down", MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up", + MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm", + MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj", + MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k", + MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b", # vision MODEL_TENSOR.V_MMPROJ: "mm.{bid}", MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", @@ -1485,6 +1511,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_NORM, MODEL_TENSOR.CLS, MODEL_TENSOR.CLS_OUT, + MODEL_TENSOR.CLS_NORM, ], MODEL_ARCH.NOMIC_BERT: [ MODEL_TENSOR.TOKEN_EMBD, @@ -1818,6 +1845,61 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.QWEN35: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_ALPHA, + MODEL_TENSOR.SSM_OUT + ], + MODEL_ARCH.QWEN35MOE: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_POST_NORM, + MODEL_TENSOR.ATTN_GATE, + MODEL_TENSOR.ATTN_QKV, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_GATE_INP_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.SSM_A, + MODEL_TENSOR.SSM_CONV1D, + MODEL_TENSOR.SSM_DT, + MODEL_TENSOR.SSM_NORM, + MODEL_TENSOR.SSM_BETA, + MODEL_TENSOR.SSM_ALPHA, + MODEL_TENSOR.SSM_OUT + ], MODEL_ARCH.PLAMO: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -2583,6 +2665,13 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP, MODEL_TENSOR.ATTN_POST_NORM, MODEL_TENSOR.FFN_POST_NORM, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], MODEL_ARCH.GLM4_MOE: [ MODEL_TENSOR.TOKEN_EMBD, @@ -2615,6 +2704,47 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.GLM_DSA: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.INDEXER_K_NORM, + MODEL_TENSOR.INDEXER_PROJ, + MODEL_TENSOR.INDEXER_ATTN_K, + MODEL_TENSOR.INDEXER_ATTN_Q_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -2689,6 +2819,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE, MODEL_TENSOR.FFN_UP, ], + MODEL_ARCH.JAIS2: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], MODEL_ARCH.NEMOTRON: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -3704,6 +3847,7 @@ class VisionProjectorType: VOXTRAL = "voxtral" LFM2 = "lfm2" KIMIVL = "kimivl" + KIMIK25 = "kimik25" LIGHTONOCR = "lightonocr" COGVLM = "cogvlm" JANUS_PRO = "janus_pro" @@ -3711,6 +3855,7 @@ class VisionProjectorType: MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" YOUTUVL = "youtuvl" + NEMOTRON_V2_VL = "nemotron_v2_vl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 62172b24c3..4245d18bc4 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -708,6 +708,9 @@ class GGUFWriter: def add_leading_dense_block_count(self, length: int) -> None: self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length) + def add_full_attention_interval(self, interval: int) -> None: + self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval) + def add_feed_forward_length(self, length: int | Sequence[int]) -> None: if isinstance(length, int): self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length) @@ -768,6 +771,15 @@ class GGUFWriter: def add_value_length_mla(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_indexer_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count) + + def add_indexer_key_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length) + + def add_indexer_top_k(self, top_k: int) -> None: + self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 167ade7803..5fc75c52eb 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -228,6 +228,7 @@ class TensorNameMap: "transformer_encoder.{bid}.qkv", # neobert "layers.{bid}.attn.Wqkv", # modern-bert "model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm + "model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5 ), # Attention query @@ -359,6 +360,7 @@ class TensorNameMap: MODEL_TENSOR.ATTN_GATE: ( "model.layers.{bid}.self_attn.gate_proj", # afmoe + "model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5 "model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate ), @@ -823,6 +825,10 @@ class TensorNameMap: "model.layers.layers.{bid}.mixer.out_proj", # plamo2 ), + MODEL_TENSOR.SSM_ALPHA: ( + "model.layers.{bid}.linear_attn.in_proj_a", # qwen3.5 + ), + MODEL_TENSOR.SSM_BETA_ALPHA: ( "model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next ), @@ -844,7 +850,8 @@ class TensorNameMap: "model.layers.{bid}.self_attn.f_b_proj", ), MODEL_TENSOR.SSM_BETA: ( - "model.layers.{bid}.self_attn.b_proj", + "model.layers.{bid}.linear_attn.in_proj_b", # qwen3.5 + "model.layers.{bid}.self_attn.b_proj", # Kimi Linear ), MODEL_TENSOR.SSM_G_A: ( "model.layers.{bid}.self_attn.g_a_proj", @@ -1199,6 +1206,22 @@ class TensorNameMap: "model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm ), + MODEL_TENSOR.INDEXER_K_NORM: ( + "model.layers.{bid}.self_attn.indexer.k_norm", # DSA + ), + + MODEL_TENSOR.INDEXER_PROJ: ( + "model.layers.{bid}.self_attn.indexer.weights_proj", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_K: ( + "model.layers.{bid}.self_attn.indexer.wk", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_Q_B: ( + "model.layers.{bid}.self_attn.indexer.wq_b", # DSA + ), + ############################################################################ # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg MODEL_TENSOR.ENC_OUTPUT_NORM: ( @@ -1217,6 +1240,10 @@ class TensorNameMap: MODEL_TENSOR.CLS_OUT: ( "classifier.out_proj", # roberta ), + + MODEL_TENSOR.CLS_NORM: ( + "head.norm", # modern-bert + ), ############################################################################# MODEL_TENSOR.CONVNEXT_DW: ( @@ -1296,6 +1323,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", + "mm_projector.proj.linear_{bid}", # Kimi-K2.5 "visual.merger.mlp.{bid}", # qwen2vl "merger.mlp.{bid}", ), @@ -1322,6 +1350,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "vision_model.radio_model.model.patch_generator.cls_token.token", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_EMBD_PATCH: ( @@ -1336,6 +1365,7 @@ class TensorNameMap: "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm "siglip2.vision_model.embeddings.patch_embedding", + "vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1352,11 +1382,14 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm "visual.embeddings.position_embedding", # glm4v + "vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_ATTN_QKV: ( "visual.blocks.{bid}.attn.qkv", # qwen3vl "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm + "vision_tower.encoder.blocks.{bid}.wqkv", # Kimi-K2.5 + "vision_model.radio_model.model.blocks.{bid}.attn.qkv", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1375,6 +1408,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.attn.q_norm", # InternVL "model.vision_tower.encoder.layer.{bid}.attention.q_norm", # Intern-S1 + "visual.blocks.{bid}.attn.q_norm", # GLM-OCR ), MODEL_TENSOR.V_ENC_ATTN_K: ( @@ -1393,6 +1427,7 @@ class TensorNameMap: MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( "vision_tower.vision_model.encoder.layers.{bid}.attn.k_norm", # InternVL "model.vision_tower.encoder.layer.{bid}.attention.k_norm", # Intern-S1 + "visual.blocks.{bid}.attn.k_norm", # GLM-OCR ), MODEL_TENSOR.V_ENC_ATTN_V: ( @@ -1421,6 +1456,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_model.radio_model.model.blocks.{bid}.norm1", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1437,6 +1473,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl + "vision_model.radio_model.model.blocks.{bid}.attn.proj", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1452,6 +1489,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_model.radio_model.model.blocks.{bid}.norm2", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1468,6 +1506,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", + "vision_model.radio_model.model.blocks.{bid}.mlp.fc1", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1490,6 +1529,7 @@ class TensorNameMap: "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", + "vision_model.radio_model.model.blocks.{bid}.mlp.fc2", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1531,6 +1571,7 @@ class TensorNameMap: "multi_modal_projector.norm", "multi_modal_projector.layer_norm", "multi_modal_projector.pre_norm", + "mm_projector.pre_norm", # Kimi-K2.5 "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm "merger.ln_q", diff --git a/include/llama.h b/include/llama.h index fa1872add1..e17f16eed5 100644 --- a/include/llama.h +++ b/include/llama.h @@ -491,7 +491,7 @@ extern "C" { enum llama_params_fit_status { LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit - LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path + LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path }; // fits mparams and cparams to free device memory (assumes system memory is unlimited) @@ -665,21 +665,12 @@ extern "C" { // The following functions operate on a llama_context, hence the naming: llama_verb_... - // Add a loaded LoRA adapter to given context - // This will not modify model's weight - LLAMA_API int32_t llama_set_adapter_lora( + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. + LLAMA_API int32_t llama_set_adapters_lora( struct llama_context * ctx, - struct llama_adapter_lora * adapter, - float scale); - - // Remove a specific LoRA adapter from given context - // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_rm_adapter_lora( - struct llama_context * ctx, - struct llama_adapter_lora * adapter); - - // Remove all LoRA adapters from given context - LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); + struct llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -687,7 +678,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_apply_adapter_cvec( + LLAMA_API int32_t llama_set_adapter_cvec( struct llama_context * ctx, const float * data, size_t len, @@ -1163,9 +1154,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index bd635f3b9d..067f5d466b 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -30,12 +30,18 @@ fi PR=$1 [[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } +url_origin=$(git config --get remote.upstream.url 2>/dev/null) || \ url_origin=$(git config --get remote.origin.url) || { - echo "error: no remote named 'origin' in this repository" + echo "error: no remote named 'upstream' or 'origin' in this repository" exit 1 } -org_repo=$(echo $url_origin | cut -d/ -f4-) +# Extract org/repo from either https or ssh format. +if [[ $url_origin =~ ^git@ ]]; then + org_repo=$(echo $url_origin | cut -d: -f2) +else + org_repo=$(echo $url_origin | cut -d/ -f4-) +fi org_repo=${org_repo%.git} echo "org/repo: $org_repo" diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 81e79a9470..02a096882e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -a8db410a252c8c8f2d120c6f2e7133ebe032f35d +d6754f3d0e6d0acd21c12442353c9fd2f94188e7 diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 1ff6a9a40f..fe1286d009 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 import urllib.request +import os +import sys +import subprocess + +HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", @@ -12,8 +17,9 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } @@ -22,19 +28,16 @@ for url, filename in vendor.items(): print(f"downloading {url} to {filename}") # noqa: NP100 urllib.request.urlretrieve(url, filename) - # split cpp/h files for httplib - # see: https://github.com/yhirose/cpp-httplib/blob/master/split.py - if 'httplib.h' in filename: - border = '// ----------------------------------------------------------------------------' - with open(filename, 'r') as f: - content = f.read() - header, implementation, footer = content.split(border, 2) - fname_cpp = filename.replace('.h', '.cpp') - with open(filename, 'w') as fh: - fh.write(header) - fh.write(footer) - with open(fname_cpp, 'w') as fc: - fc.write('#include "httplib.h"\n') - fc.write('namespace httplib {\n') - fc.write(implementation.replace('\ninline ', '\n')) - fc.write('} // namespace httplib\n') +print("Splitting httplib.h...") # noqa: NP100 +try: + subprocess.check_call([ + sys.executable, "split.py", + "--extension", "cpp", + "--out", "vendor/cpp-httplib" + ]) +except Exception as e: + print(f"Error: {e}") # noqa: NP100 + sys.exit(1) +finally: + os.remove("split.py") + os.remove("httplib.h") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2115fc4255..c10d5c70fb 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,13 +57,14 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta-net-base.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/exaone-moe.cpp models/exaone.cpp models/exaone4.cpp - models/exaone-moe.cpp models/falcon-h1.cpp models/falcon.cpp models/gemma-embedding.cpp @@ -83,6 +84,7 @@ add_library(llama models/hunyuan-moe.cpp models/internlm2.cpp models/jais.cpp + models/jais2.cpp models/jamba.cpp models/kimi-linear.cpp models/lfm2.cpp @@ -91,10 +93,12 @@ add_library(llama models/llama-iswa.cpp models/llama.cpp models/maincoder.cpp + models/mamba-base.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/mistral3.cpp models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp @@ -118,10 +122,12 @@ add_library(llama models/qwen2moe.cpp models/qwen2vl.cpp models/qwen3.cpp - models/qwen3vl.cpp - models/qwen3vl-moe.cpp + models/qwen35.cpp + models/qwen35moe.cpp models/qwen3moe.cpp models/qwen3next.cpp + models/qwen3vl-moe.cpp + models/qwen3vl.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -140,8 +146,6 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp - models/mistral3.cpp - models/graph-context-mamba.cpp ) set_target_properties(llama PROPERTIES diff --git a/src/llama-adapter.h b/src/llama-adapter.h index d275d25425..aa3ab63ad7 100644 --- a/src/llama-adapter.h +++ b/src/llama-adapter.h @@ -39,6 +39,8 @@ private: std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -84,3 +86,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index bd78f1e556..3cb45b6922 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -37,6 +37,8 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_QWEN3NEXT, "qwen3next" }, { LLM_ARCH_QWEN3VL, "qwen3vl" }, { LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" }, + { LLM_ARCH_QWEN35, "qwen35" }, + { LLM_ARCH_QWEN35MOE, "qwen35moe" }, { LLM_ARCH_PHI2, "phi2" }, { LLM_ARCH_PHI3, "phi3" }, { LLM_ARCH_PHIMOE, "phimoe" }, @@ -72,10 +74,12 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, { LLM_ARCH_JAIS, "jais" }, + { LLM_ARCH_JAIS2, "jais2" }, { LLM_ARCH_NEMOTRON, "nemotron" }, { LLM_ARCH_NEMOTRON_H, "nemotron_h" }, { LLM_ARCH_NEMOTRON_H_MOE, "nemotron_h_moe" }, @@ -195,6 +199,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_EMBEDDING_SCALE, "%s.embedding_scale" }, { LLM_KV_TOKEN_SHIFT_COUNT, "%s.token_shift_count" }, { LLM_KV_INTERLEAVE_MOE_LAYER_STEP, "%s.interleave_moe_layer_step" }, + { LLM_KV_FULL_ATTENTION_INTERVAL, "%s.full_attention_interval" }, { LLM_KV_ATTENTION_HEAD_COUNT, "%s.attention.head_count" }, { LLM_KV_ATTENTION_HEAD_COUNT_KV, "%s.attention.head_count_kv" }, @@ -222,6 +227,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -360,12 +368,14 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_TOKEN_TYPES, "token_types" }, { LLM_TENSOR_CLS, "cls" }, { LLM_TENSOR_CLS_OUT, "cls.output" }, + { LLM_TENSOR_CLS_NORM, "cls.norm" }, { LLM_TENSOR_ENC_OUTPUT_NORM, "enc.output_norm" }, { LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" }, { LLM_TENSOR_SSM_A_NOSCAN, "blk.%d.ssm_a" }, { LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" }, { LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" }, { LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" }, + { LLM_TENSOR_SSM_ALPHA, "blk.%d.ssm_alpha" }, { LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" }, { LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" }, { LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" }, @@ -512,6 +522,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -816,6 +830,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, }; case LLM_ARCH_JINA_BERT_V2: return { @@ -968,7 +983,6 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_ATTN_OUT, LLM_TENSOR_ATTN_QKV, LLM_TENSOR_ATTN_GATE, - LLM_TENSOR_FFN_NORM, LLM_TENSOR_FFN_GATE_INP, LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, @@ -985,6 +999,63 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, }; + case LLM_ARCH_QWEN35: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; + case LLM_ARCH_QWEN35MOE: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_POST_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_ATTN_QKV, + LLM_TENSOR_ATTN_GATE, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_SSM_A_NOSCAN, + LLM_TENSOR_SSM_CONV1D, + LLM_TENSOR_SSM_DT, + LLM_TENSOR_SSM_BETA, + LLM_TENSOR_SSM_ALPHA, + LLM_TENSOR_SSM_NORM, + LLM_TENSOR_SSM_OUT, + }; case LLM_ARCH_QWEN3VL: case LLM_ARCH_CHAMELEON: case LLM_ARCH_HUNYUAN_DENSE: @@ -1565,6 +1636,12 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_DOWN, LLM_TENSOR_ATTN_POST_NORM, LLM_TENSOR_FFN_POST_NORM, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; case LLM_ARCH_GLM4_MOE: return { @@ -1597,6 +1674,46 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, @@ -1675,6 +1792,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE, LLM_TENSOR_FFN_DOWN, }; + case LLM_ARCH_JAIS2: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_NEMOTRON_H: return { LLM_TENSOR_TOKEN_EMBD, @@ -2404,6 +2535,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_CLS_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2456,6 +2588,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_SSM_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, @@ -2582,6 +2715,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, @@ -2675,6 +2812,8 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_NEMOTRON_H_MOE: case LLM_ARCH_QWEN3NEXT: case LLM_ARCH_KIMI_LINEAR: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return true; default: return false; diff --git a/src/llama-arch.h b/src/llama-arch.h index e8263369b8..43ca9a6a48 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -41,6 +41,8 @@ enum llm_arch { LLM_ARCH_QWEN3NEXT, LLM_ARCH_QWEN3VL, LLM_ARCH_QWEN3VLMOE, + LLM_ARCH_QWEN35, + LLM_ARCH_QWEN35MOE, LLM_ARCH_PHI2, LLM_ARCH_PHI3, LLM_ARCH_PHIMOE, @@ -76,10 +78,12 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, LLM_ARCH_JAIS, + LLM_ARCH_JAIS2, LLM_ARCH_NEMOTRON, LLM_ARCH_NEMOTRON_H, LLM_ARCH_NEMOTRON_H_MOE, @@ -199,6 +203,7 @@ enum llm_kv { LLM_KV_EMBEDDING_SCALE, LLM_KV_TOKEN_SHIFT_COUNT, LLM_KV_INTERLEAVE_MOE_LAYER_STEP, + LLM_KV_FULL_ATTENTION_INTERVAL, LLM_KV_ATTENTION_HEAD_COUNT, LLM_KV_ATTENTION_HEAD_COUNT_KV, @@ -226,6 +231,9 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -404,13 +412,14 @@ enum llm_tensor { LLM_TENSOR_SSM_NORM, LLM_TENSOR_SSM_OUT, LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next + LLM_TENSOR_SSM_ALPHA, // qwen3.5 // Kimi Linear KDA (using SSM_ prefix for consistency) LLM_TENSOR_SSM_CONV1D_Q, // kimi: Q conv1d weight LLM_TENSOR_SSM_CONV1D_K, // kimi: K conv1d weight LLM_TENSOR_SSM_CONV1D_V, // kimi: V conv1d weight LLM_TENSOR_SSM_F_A, // kimi: forget gate projection A LLM_TENSOR_SSM_F_B, // kimi: forget gate projection B - LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient + LLM_TENSOR_SSM_BETA, // kimi: beta mixing coefficient and qwen3.5 LLM_TENSOR_SSM_G_A, // kimi: output gate projection A LLM_TENSOR_SSM_G_B, // kimi: output gate projection B LLM_TENSOR_TIME_MIX_W0, @@ -489,6 +498,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, + LLM_TENSOR_CLS_NORM, LLM_TENSOR_CONV1D, LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, @@ -513,6 +523,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 67947894f0..99119242be 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -22,6 +22,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -694,7 +696,7 @@ enum llama_pooling_type llama_context::pooling_type() const { float * llama_context::get_logits() { output_reorder(); - return logits; + return logits.data; } int64_t llama_context::output_resolve_row(int32_t i) const { @@ -727,36 +729,15 @@ int64_t llama_context::output_resolve_row(int32_t i) const { } float * llama_context::get_logits_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (logits == nullptr) { + if (logits.data == nullptr) { throw std::runtime_error("no logits"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - - return logits + j*model.vocab.n_tokens(); + const int64_t j = output_resolve_row(i); + return logits.data + j*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid logits id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -770,45 +751,24 @@ float * llama_context::get_logits_ith(int32_t i) { float * llama_context::get_embeddings() { output_reorder(); - return embd; + return embd.data; } llama_token * llama_context::get_sampled_tokens() const{ - return sampling.sampled; + return sampling.sampled.data; } float * llama_context::get_embeddings_ith(int32_t i) { - int64_t j = -1; - output_reorder(); try { - if (embd == nullptr) { + if (embd.data == nullptr) { throw std::runtime_error("no embeddings"); } - // TODO: use output_resolve_row() - if (i < 0) { - j = n_outputs + i; - if (j < 0) { - throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); - } - } else if ((size_t) i >= output_ids.size()) { - throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); - } else { - j = output_ids[i]; - } - - if (j < 0) { - throw std::runtime_error(format("batch.logits[%d] != true", i)); - } - if (j >= n_outputs) { - // This should not happen - throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); - } - + const int64_t j = output_resolve_row(i); const uint32_t n_embd_out = model.hparams.n_embd_out(); - return embd + j*n_embd_out; + return embd.data + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -848,14 +808,14 @@ int32_t llama_context::cpy_mtp_state(llama_context & ctx_mtp) { llama_token llama_context::get_sampled_token_ith(int32_t idx) { output_reorder(); - if (sampling.sampled == nullptr) { + if (!sampling.sampled.has_data()) { return LLAMA_TOKEN_NULL; } try { const int64_t row = output_resolve_row(idx); - GGML_ASSERT(row < (int64_t) sampling.sampled_size); - return sampling.sampled[row]; + GGML_ASSERT(row < (int64_t) sampling.sampled.size); + return sampling.sampled.data[row]; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); return LLAMA_TOKEN_NULL; @@ -865,7 +825,7 @@ llama_token llama_context::get_sampled_token_ith(int32_t idx) { float * llama_context::get_sampled_probs_ith(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return nullptr; } @@ -874,7 +834,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { return nullptr; } - return sampling.probs + row*model.vocab.n_tokens(); + return sampling.probs.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -884,7 +844,7 @@ float * llama_context::get_sampled_probs_ith(int32_t idx) { float * llama_context::get_sampled_logits_ith(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return nullptr; } @@ -893,7 +853,7 @@ float * llama_context::get_sampled_logits_ith(int32_t idx) { if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { return nullptr; } - return sampling.logits + row*model.vocab.n_tokens(); + return sampling.logits.data + row*model.vocab.n_tokens(); } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); return nullptr; @@ -905,13 +865,14 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { try { const int64_t row = output_resolve_row(idx); - if (sampling.candidates != nullptr && + if (sampling.candidates.has_data() && (size_t) row < sampling.candidates_count.size() && sampling.candidates_count[row] > 0) { - return sampling.candidates + row*model.vocab.n_tokens(); + return sampling.candidates.data + row*model.vocab.n_tokens(); } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -920,7 +881,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { size_t llama_context::get_sampled_candidates_count(int32_t idx) { output_reorder(); - if (sampling.candidates == nullptr) { + if (!sampling.candidates.has_data()) { return 0; } @@ -939,7 +900,7 @@ size_t llama_context::get_sampled_candidates_count(int32_t idx) { size_t llama_context::get_sampled_logits_count(int32_t idx) { output_reorder(); - if (sampling.logits == nullptr) { + if (!sampling.logits.has_data()) { return model.vocab.n_tokens(); } @@ -958,7 +919,7 @@ size_t llama_context::get_sampled_logits_count(int32_t idx) { size_t llama_context::get_sampled_probs_count(int32_t idx) { output_reorder(); - if (sampling.probs == nullptr) { + if (!sampling.probs.has_data()) { return 0; } @@ -1091,51 +1052,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { return true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (auto it = loras.find(adapter); it != loras.end()) { - if (it->second == scale) { - return; - } - } - - loras[adapter] = scale; - - sched_need_reserve = true; -} - -bool llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - - auto it = loras.find(adapter); - if (it != loras.end()) { - loras.erase(it); - - sched_need_reserve = true; - - return true; - } - - return false; -} - -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); - - if (loras.empty()) { + if (adapters_lora_are_same(adapters, n_adapters, scales)) { return; } - loras.clear(); + loras.reset(new llama_adapter_loras()); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras->insert({adapters[i], scales[i]}); + } + } sched_need_reserve = true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); + + if (n_adapters != loras->size()) { + return false; + } + + for (size_t i = 0; i < n_adapters; i ++) { + auto it = loras->find(adapters[i]); + + if (it == loras->end() || it->second != scales[i]) { + return false; + } + } + + return true; +} + +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -1145,7 +1098,7 @@ bool llama_context::apply_adapter_cvec( // TODO: should we reserve? - return cvec.apply(model, data, len, n_embd, il_start, il_end); + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1288,16 +1241,16 @@ int llama_context::encode(const llama_batch & batch_inp) { auto * t_embd = res->get_embd_pooled() ? res->get_embd_pooled() : res->get_embd(); // extract logits - if (logits && t_logits) { + if (logits.data && t_logits) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - ggml_backend_tensor_get_async(backend_res, t_logits, logits, 0, n_tokens*n_vocab*sizeof(float)); + ggml_backend_tensor_get_async(backend_res, t_logits, logits.data, 0, n_tokens*n_vocab*sizeof(float)); } // extract embeddings - if (embd && t_embd) { + if (embd.data && t_embd) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1305,11 +1258,11 @@ int llama_context::encode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd.size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd.data, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1357,7 +1310,7 @@ int llama_context::encode(const llama_batch & batch_inp) { cross.n_embd = t_embd->ne[0]; cross.n_enc = t_embd->ne[1]; cross.v_embd.resize(cross.n_embd*cross.n_enc); - memcpy(cross.v_embd.data(), embd, ggml_nbytes(t_embd)); + memcpy(cross.v_embd.data(), embd.data, ggml_nbytes(t_embd)); const auto & batch = balloc->get_batch(); @@ -1397,11 +1350,10 @@ static std::map build_seq_to_output_row(const llama_ubat static void copy_tensor_async_ints( const std::map & tensor_map, - llama_token * sampled, - size_t sampled_size, + const buffer_view & sampled, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (sampled == nullptr) { + if (!sampled.has_data()) { return; } @@ -1412,23 +1364,23 @@ static void copy_tensor_async_ints( } const uint32_t row = it->second; - GGML_ASSERT(row < sampled_size); + GGML_ASSERT(row < sampled.size); GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + ggml_backend_tensor_get_async(backend, tensor, sampled.data + row, 0, sizeof(sampled.data[row])); } } static void copy_tensor_async_floats( const std::map & tensor_map, - float * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1444,7 +1396,7 @@ static void copy_tensor_async_floats( GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - float * row_ptr = dst + (size_t) row * stride; + float * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of logits/probabilities that were written for this row. @@ -1454,12 +1406,12 @@ static void copy_tensor_async_floats( static void copy_tensor_async_candidates( const std::map & tensor_map, - llama_token * dst, + const buffer_view & dst, size_t stride, std::vector & counts, const std::map & seq_to_row, ggml_backend_sched_t sched) { - if (dst == nullptr) { + if (!dst.has_data()) { return; } @@ -1475,7 +1427,7 @@ static void copy_tensor_async_candidates( GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); - llama_token * row_ptr = dst + (size_t) row * stride; + llama_token * row_ptr = dst.data + (size_t) row * stride; ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); // Update the actual number of candidates that were written. @@ -1726,22 +1678,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (logits && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { + if (logits.data && t_logits && n_outputs > 0 && needs_raw_logits(ubatch, sampling.samplers)) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); - GGML_ASSERT(logits != nullptr); + GGML_ASSERT(logits.data != nullptr); - float * logits_out = logits + n_outputs_prev*n_vocab; + float * logits_out = logits.data + n_outputs_prev*n_vocab; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_vocab <= (int64_t) logits.size); ggml_backend_tensor_get_async(backend_res, t_logits, logits_out, 0, n_outputs*n_vocab*sizeof(float)); } } // extract embeddings - if (embd && t_embd && n_outputs > 0) { + if (embd.data && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1757,13 +1709,13 @@ int llama_context::decode(const llama_batch & batch_inp) { case LLAMA_POOLING_TYPE_NONE: { // extract token embeddings - GGML_ASSERT(embd != nullptr); + GGML_ASSERT(embd.data != nullptr); const uint32_t n_embd_out = hparams.n_embd_out(); - float * embd_out = embd + n_outputs_prev*n_embd_out; + float * embd_out = embd.data + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd.size); ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; @@ -1810,7 +1762,7 @@ int llama_context::decode(const llama_batch & batch_inp) { const auto stride = n_vocab; // async copy the sampling data from the backend to the host - copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + copy_tensor_async_ints(res->t_sampled, sampling.sampled, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); @@ -1881,7 +1833,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // uint32_t llama_context::output_reserve(int32_t n_outputs) { - const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1904,19 +1855,14 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { size_t backend_float_count = 0; size_t backend_token_count = 0; - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + logits.size = has_logits ? n_vocab*n_outputs_max : 0; + embd.size = has_embd ? n_embd_out*n_outputs_max : 0; // Allocate backend sampling output buffers if there are backend samplers configured. const bool has_sampling = !sampling.samplers.empty(); if (has_sampling) { - sampling.logits_size = n_vocab*n_outputs_max; - sampling.probs_size = n_vocab*n_outputs_max; - sampling.sampled_size = n_outputs_max; - sampling.candidates_size = n_vocab*n_outputs_max; - - backend_float_count = sampling.logits_size + sampling.probs_size; - backend_token_count = sampling.sampled_size + sampling.candidates_size; + backend_float_count = 2 * n_vocab * n_outputs_max; // logits + probs + backend_token_count = (1 + n_vocab) * n_outputs_max; // sampled + candidates } if (output_ids.empty()) { @@ -1926,7 +1872,7 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; const size_t new_size = - (logits_size + embd_size + backend_float_count) * sizeof(float) + + (logits.size + embd.size + backend_float_count) * sizeof(float) + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required @@ -1941,8 +1887,8 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { // TODO: not needed? buf_output = nullptr; - logits = nullptr; - embd = nullptr; + logits.data = nullptr; + embd.data = nullptr; } auto * buft = ggml_backend_cpu_buffer_type(); @@ -1961,35 +1907,27 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = nullptr; - embd = nullptr; - size_t offset = 0; uint8_t * base = (uint8_t *) output_base; - logits = has_logits ? output_base : nullptr; - offset += logits_size * sizeof(float); + logits = has_logits ? buffer_view{output_base, logits.size} : buffer_view{nullptr, 0}; + offset += logits.size * sizeof(float); - embd = has_embd ? (float *) (base + offset) : nullptr; - offset += embd_size * sizeof(float); - - sampling.logits = nullptr; - sampling.probs = nullptr; - sampling.sampled = nullptr; - sampling.candidates = nullptr; + embd = has_embd ? buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; + offset += embd.size * sizeof(float); if (has_sampling) { - sampling.logits = (float *) (base + offset); - offset += sampling.logits_size * sizeof(float); + sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.logits.size * sizeof(float); - sampling.probs = (float *) (base + offset); - offset += sampling.probs_size * sizeof(float); + sampling.probs = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.probs.size * sizeof(float); - sampling.sampled = (llama_token *) (base + offset); - offset += sampling.sampled_size * sizeof(llama_token); + sampling.sampled = {(llama_token *) (base + offset), (size_t)n_outputs_max}; + offset += sampling.sampled.size * sizeof(llama_token); - sampling.candidates = (llama_token *) (base + offset); - offset += sampling.candidates_size * sizeof(llama_token); + sampling.candidates = {(llama_token *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; + offset += sampling.candidates.size * sizeof(llama_token); // The count vectors keep track of the actual number of logits/probs/candidates // copied from the backend for each output row. @@ -2002,7 +1940,16 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); - std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -2021,49 +1968,42 @@ void llama_context::output_reorder() { const uint64_t i0 = output_swaps[s].i0; const uint64_t i1 = output_swaps[s].i1; - if (logits_size > 0) { + if (logits.size > 0) { for (uint64_t k = 0; k < n_vocab; k++) { - std::swap(logits[i0*n_vocab + k], logits[i1*n_vocab + k]); + std::swap(logits.data[i0*n_vocab + k], logits.data[i1*n_vocab + k]); } } - if (embd_size > 0) { + if (embd.size > 0) { for (uint64_t k = 0; k < n_embd; k++) { - std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); + std::swap(embd.data[i0*n_embd + k], embd.data[i1*n_embd + k]); } } - if (sampling.logits && sampling.logits_size > 0) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - if (sampling.probs && sampling.probs_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } - } - if (sampling.candidates && sampling.candidates_size > 0) { for (uint64_t k = 0; k < n_vocab; ++k) { - std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } - } - if (sampling.sampled && sampling.sampled_size > 0) { - std::swap(sampling.sampled[i0], sampling.sampled[i1]); - } - - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } - - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } - - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } @@ -2076,7 +2016,7 @@ void llama_context::output_reorder() { // uint32_t llama_context::graph_max_nodes(uint32_t n_tokens) const { - if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR) { + if (model.arch == LLM_ARCH_QWEN3NEXT || model.arch == LLM_ARCH_KIMI_LINEAR || model.arch == LLM_ARCH_QWEN35 || model.arch == LLM_ARCH_QWEN35MOE) { return std::max(n_tokens * 40, 32u * model.n_tensors()); } uint32_t res = std::max(1024u, 8u*model.n_tensors()); @@ -2164,8 +2104,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, @@ -2596,12 +2536,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - const uint64_t logits_size = std::min((uint64_t) this->logits_size, (uint64_t) n_outputs * model.vocab.n_tokens()); + const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); io.write(&logits_size, sizeof(logits_size)); if (logits_size) { - io.write(logits, logits_size * sizeof(float)); + io.write(logits.data, logits_size * sizeof(float)); } } @@ -2609,12 +2549,12 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { { LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - const uint64_t embd_size = std::min((uint64_t) this->embd_size, (uint64_t) n_outputs * model.hparams.n_embd); + const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); io.write(&embd_size, sizeof(embd_size)); if (embd_size) { - io.write(embd, embd_size * sizeof(float)); + io.write(embd.data, embd_size * sizeof(float)); } } @@ -2682,12 +2622,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t logits_size; io.read_to(&logits_size, sizeof(logits_size)); - if (this->logits_size < logits_size) { + if (this->logits.size < logits_size) { throw std::runtime_error("logits buffer too small"); } if (logits_size) { - io.read_to(this->logits, logits_size * sizeof(float)); + io.read_to(this->logits.data, logits_size * sizeof(float)); } } @@ -2698,12 +2638,12 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { uint64_t embd_size; io.read_to(&embd_size, sizeof(embd_size)); - if (this->embd_size < embd_size) { + if (this->embd.size < embd_size) { throw std::runtime_error("embeddings buffer too small"); } if (embd_size) { - io.read_to(this->embd, embd_size * sizeof(float)); + io.read_to(this->embd.data, embd_size * sizeof(float)); } } @@ -2842,6 +2782,7 @@ void llama_context::opt_init(struct llama_model * model, struct llama_opt_params llama_set_param(model->cls_b, param_filter, param_filter_ud); llama_set_param(model->cls_out, param_filter, param_filter_ud); llama_set_param(model->cls_out_b, param_filter, param_filter_ud); + llama_set_param(model->cls_norm, param_filter, param_filter_ud); for (struct llama_layer & layer : model->layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { @@ -3288,35 +3229,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { // llama adapter API -int32_t llama_set_adapter_lora( +int32_t llama_set_adapters_lora( llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } + + ctx->set_adapters_lora(adapters, n_adapters, scales); return 0; } -int32_t llama_rm_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); - - return res ? 0 : -1; -} - -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); -} - -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 0 : -1; } diff --git a/src/llama-context.h b/src/llama-context.h index 72d305d11c..61351b780b 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -4,6 +4,7 @@ #include "llama-cparams.h" #include "llama-graph.h" #include "llama-adapter.h" +#include "llama-impl.h" #include "ggml-cpp.h" #include "ggml-opt.h" @@ -106,16 +107,11 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool rm_adapter_lora( - llama_adapter_lora * adapter); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); - - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -262,43 +258,36 @@ private: const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - size_t logits_size = 0; // capacity (of floats) for logits - float * logits = nullptr; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - size_t embd_size = 0; // capacity (of floats) for embeddings - float * embd = nullptr; + buffer_view embd = {nullptr, 0}; - // TODO: simplify struct sampling_info { + // !samplers.empty() to check if any samplers are active std::map samplers; - float * logits = nullptr; - size_t logits_size = 0; - - llama_token * sampled = nullptr; - size_t sampled_size = 0; - - float * probs = nullptr; - size_t probs_size = 0; - - llama_token * candidates = nullptr; - size_t candidates_size = 0; + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; std::vector candidates_count; + // optimization std::vector token_ids_full_vocab; }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index e0bb206ba7..d764b13bd0 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -17,6 +17,41 @@ #include #include +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -150,7 +185,10 @@ bool llm_graph_input_out_ids::can_reuse(const llm_graph_params & params) { } void llm_graph_input_mean::set_input(const llama_ubatch * ubatch) { - if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && + (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN || + cparams.pooling_type == LLAMA_POOLING_TYPE_RANK )) { + const int64_t n_tokens = ubatch->n_tokens; const int64_t n_seq_tokens = ubatch->n_seq_tokens; const int64_t n_seqs_unq = ubatch->n_seqs_unq; @@ -414,8 +452,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -435,8 +472,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -466,11 +502,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -532,8 +565,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -576,8 +608,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -636,8 +667,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } // swa tensors may not be allocated if there are no SWA attention layers @@ -645,8 +675,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); - res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -1110,8 +1139,8 @@ ggml_tensor * llm_graph_context::build_ffn( if (down) { cur = build_lora_mm(down, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -1720,7 +1749,8 @@ ggml_tensor * llm_graph_context::build_attn_mha( ggml_tensor * cur; - if (cparams.flash_attn && kq_b == nullptr) { + const bool use_flash_attn = cparams.flash_attn && kq_b == nullptr; + if (use_flash_attn) { GGML_ASSERT(kq_b == nullptr && "Flash attention does not support KQ bias yet"); if (v_trans) { @@ -1916,14 +1946,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1983,8 +2010,8 @@ ggml_tensor * llm_graph_context::build_attn( if (wo) { cur = build_lora_mm(wo, cur); - if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE) { - // GLM4 and GLM4_MOE seem to have numerical issues with half-precision accumulators + if (arch == LLM_ARCH_GLM4 || arch == LLM_ARCH_GLM4_MOE || arch == LLM_ARCH_JAIS2) { + // GLM4, GLM4_MOE, and JAIS2 seem to have numerical issues with half-precision accumulators ggml_mul_mat_set_prec(cur, GGML_PREC_F32); } } @@ -2008,13 +2035,9 @@ static std::unique_ptr build_attn_inp_k_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -2213,15 +2236,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); ggml_set_name(inp->self_kq_mask, "self_kq_mask"); @@ -2232,12 +2251,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); @@ -2399,27 +2416,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = attn_ctx->get_base()->get_n_kv(); - inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } { - const auto n_kv = attn_ctx->get_swa()->get_n_kv(); - inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask_swa); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; @@ -2432,8 +2443,9 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() void llm_graph_context::build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const { - if (!cparams.embeddings || !(dense_2 || dense_3)) { + if (!cparams.embeddings || !(dense_2 || dense_2_b || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; @@ -2442,6 +2454,9 @@ void llm_graph_context::build_dense_out( if (dense_2) { cur = ggml_mul_mat(ctx0, dense_2, cur); } + if (dense_2_b) { + cur = ggml_add(ctx0, cur, dense_2_b); + } if (dense_3) { cur = ggml_mul_mat(ctx0, dense_3, cur); } @@ -2455,7 +2470,8 @@ void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const { + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const { if (!cparams.embeddings) { return; } @@ -2494,8 +2510,15 @@ void llm_graph_context::build_pooling( } break; case LLAMA_POOLING_TYPE_RANK: { - ggml_tensor * inp_cls = build_inp_cls(); - cur = ggml_get_rows(ctx0, inp, inp_cls); + if (arch == LLM_ARCH_MODERN_BERT) { + // modern bert gte reranker builds mean first then applies prediction head and classifier + // https://github.com/huggingface/transformers/blob/main/src/transformers/models/modernbert/modular_modernbert.py#L1404-1411 + ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } else { + ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } // classification head // https://github.com/huggingface/transformers/blob/5af7d41e49bbfc8319f462eb45253dcb3863dfb7/src/transformers/models/roberta/modeling_roberta.py#L1566 @@ -2504,7 +2527,15 @@ void llm_graph_context::build_pooling( if (cls_b) { cur = ggml_add(ctx0, cur, cls_b); } - cur = ggml_tanh(ctx0, cur); + if (arch == LLM_ARCH_MODERN_BERT) { + cur = ggml_gelu(ctx0, cur); + } else { + cur = ggml_tanh(ctx0, cur); + } + if (cls_norm) { + // head norm + cur = build_norm(cur, cls_norm, NULL, LLM_NORM, -1); + } } // some models don't have `cls_out`, for example: https://huggingface.co/jinaai/jina-reranker-v1-tiny-en diff --git a/src/llama-graph.h b/src/llama-graph.h index 6e0a817223..fb59246d21 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -1020,7 +1020,8 @@ struct llm_graph_context { ggml_tensor * cls, ggml_tensor * cls_b, ggml_tensor * cls_out, - ggml_tensor * cls_out_b) const; + ggml_tensor * cls_out_b, + ggml_tensor * cls_norm) const; // // sampling (backend sampling) @@ -1034,6 +1035,7 @@ struct llm_graph_context { void build_dense_out( ggml_tensor * dense_2, + ggml_tensor * dense_2_b, ggml_tensor * dense_3) const; }; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index bf291469c9..aa1b5981c7 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -42,7 +42,6 @@ struct llama_hparams { uint32_t n_ctx_train; // context size the model was trained on uint32_t n_embd; - uint32_t n_embd_features = 0; uint32_t n_layer; int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; @@ -194,6 +193,11 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; diff --git a/src/llama-impl.h b/src/llama-impl.h index c3391e79f5..dfd9fee9f4 100644 --- a/src/llama-impl.h +++ b/src/llama-impl.h @@ -49,6 +49,16 @@ struct time_meas { int64_t & t_acc; }; +template +struct buffer_view { + T * data; + size_t size = 0; + + bool has_data() const { + return data && size > 0; + } +}; + void replace_all(std::string & s, const std::string & search, const std::string & replace); // TODO: rename to llama_format ? diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 0261e4c72c..c03228e9ce 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -504,6 +504,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -511,7 +513,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -520,9 +522,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -554,9 +556,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 36e353074e..676efeda70 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -271,6 +271,7 @@ void llama_model_saver::add_tensors_from_model() { add_tensor(model.cls_b); add_tensor(model.cls_out); add_tensor(model.cls_out_b); + add_tensor(model.cls_norm); for (const struct llama_layer & layer : model.layers) { for (size_t i = 0; i < sizeof(layer)/sizeof(struct ggml_tensor *); ++i) { diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 27322c1a71..fa758609a7 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -125,6 +125,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; + case LLM_TYPE_35B_A3B: return "35B.A3B"; case LLM_TYPE_48B_A3B: return "48B.A3B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; @@ -136,6 +137,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -522,7 +524,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_EXPERT_GROUP_USED_COUNT, hparams.n_group_used, false); if (arch == LLM_ARCH_WAVTOKENIZER_DEC) { - ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd_features); + ml.get_key(LLM_KV_FEATURES_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd_out_impl); ml.get_key(LLM_KV_POSNET_EMBEDDING_LENGTH, hparams.posnet.n_embd); ml.get_key(LLM_KV_POSNET_BLOCK_COUNT, hparams.posnet.n_layer); @@ -905,7 +908,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); - hparams.set_swa_pattern(swa_period); + hparams.set_swa_pattern(swa_period, true); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1781,7 +1784,15 @@ void llama_model::load_hparams(llama_model_loader & ml) { { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // NextN/MTP parameters (GLM-OCR) + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + switch (hparams.n_layer) { + case 17: type = LLM_TYPE_1B; break; // GLM-OCR case 40: type = LLM_TYPE_9B; break; case 61: type = LLM_TYPE_32B; break; default: type = LLM_TYPE_UNKNOWN; @@ -1820,6 +1831,50 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1882,6 +1937,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_JAIS2: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + case 68: type = LLM_TYPE_70B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_NEMOTRON: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -2293,6 +2358,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { case 10752: type = LLM_TYPE_2_6B; break; default: type = LLM_TYPE_UNKNOWN; } + if (const auto is_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); is_swa && hparams.n_swa > 0) { + hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.swa_layers[il] = !hparams.recurrent_layer_arr[il]; + } + } } break; case LLM_ARCH_LFM2MOE: { @@ -2403,8 +2474,12 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); // Mark recurrent layers (linear attention layers) - for (uint32_t i = 0; i < hparams.n_layer; ++i) { - hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval" + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } } switch (hparams.n_layer) { @@ -2412,6 +2487,62 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_QWEN35: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_2B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; + case LLM_ARCH_QWEN35MOE: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false); + ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, true); + + // Load linear attention (gated delta net) parameters + ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv); + ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner); + ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state); + ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank); + ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group); + + // Mark recurrent layers (linear attention layers) + { + uint32_t full_attn_interval = 4; + ml.get_key(LLM_KV_FULL_ATTENTION_INTERVAL, full_attn_interval, false); + for (uint32_t i = 0; i < hparams.n_layer; ++i) { + hparams.recurrent_layer_arr[i] = ((i + 1) % full_attn_interval != 0); + } + } + + switch (hparams.n_layer) { + case 28: type = LLM_TYPE_35B_A3B; break; + case 48: type = LLM_TYPE_80B_A3B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_MISTRAL3: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -3398,9 +3529,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); } - cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); - cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); - cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out = create_tensor(tn(LLM_TENSOR_CLS_OUT, "weight"), {n_embd, hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls_out_b = create_tensor(tn(LLM_TENSOR_CLS_OUT, "bias"), {hparams.n_cls_out}, TENSOR_NOT_REQUIRED); + cls = create_tensor(tn(LLM_TENSOR_CLS, "weight"), {n_embd, n_embd}, TENSOR_NOT_REQUIRED); + cls_norm = create_tensor(tn(LLM_TENSOR_CLS_NORM, "weight"), {n_embd}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_NEO_BERT: @@ -5253,6 +5385,45 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); } } break; + case LLM_ARCH_JAIS2: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "bias", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + // attention biases - all have shape n_embd (output dimension of projections) + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, 0); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd}, 0); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd}, 0); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), {n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_norm_b = create_tensor(tn(LLM_TENSOR_FFN_NORM, "bias", i), {n_embd}, 0); + + // Jais-2 uses simple MLP (no gate) with biases + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); + } + } break; case LLM_ARCH_CHATGLM: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5303,30 +5474,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } for (int i = 0; i < n_layer; ++i) { - auto & layer = layers[i]; - - layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); - layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, TENSOR_NOT_REQUIRED); - - if (layer.wqkv == nullptr) { - layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); - layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, 0); - layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, 0); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, TENSOR_NOT_REQUIRED); + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + flags |= TENSOR_SKIP; } - layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + auto & layer = layers[i]; - layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, 0); + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "bias", i), {n_embd + 2*n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); - layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, 0); + if (layer.wqkv == nullptr) { + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, flags); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_k_gqa}, flags); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_v_gqa}, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), {n_embd}, flags | TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), {n_embd_gqa}, flags | TENSOR_NOT_REQUIRED); + } - layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, flags); + + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), {n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff * 2}, flags); + + layer.ffn_post_norm = create_tensor(tn(LLM_TENSOR_FFN_POST_NORM, "weight", i), {n_embd}, flags); + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } } } break; case LLM_ARCH_GLM4_MOE: @@ -5430,6 +5619,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -5985,9 +6276,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_WAVTOKENIZER_DEC: { - tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd_features, n_vocab}, 0); + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {hparams.n_embd, n_vocab}, 0); - conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd_features, hparams.posnet.n_embd}, 0); + conv1d = create_tensor(tn(LLM_TENSOR_CONV1D, "weight"), {7, hparams.n_embd, hparams.posnet.n_embd}, 0); conv1d_b = create_tensor(tn(LLM_TENSOR_CONV1D, "bias"), {1, hparams.posnet.n_embd}, 0); // posnet @@ -6083,8 +6374,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output_norm_b = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "bias"), {n_embd}, 0); } - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, n_embd}, 0); - output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {hparams.convnext.n_embd, hparams.n_embd_out()}, 0); + output_b = create_tensor(tn(LLM_TENSOR_OUTPUT, "bias"), {hparams.n_embd_out()}, 0); } break; case LLM_ARCH_BAILINGMOE: { @@ -6660,7 +6951,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // for LFM2-ColBert-350M - dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.n_embd_out()}, TENSOR_NOT_REQUIRED); + dense_2_out_layers_b = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "bias"), {hparams.n_embd_out() }, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -7101,6 +7393,131 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0); } } break; + case LLM_ARCH_QWEN35MOE: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used; + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + + // Shared experts + const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; + + layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, n_ff_shexp }, 0); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_shexp, n_embd }, 0); + } + } break; + case LLM_ARCH_QWEN35: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED); + + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED); + } + + // Calculate dimensions from hyperparameters + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t head_v_dim = hparams.ssm_d_state; + const int64_t n_k_heads = hparams.ssm_n_group; + const int64_t n_v_heads = hparams.ssm_dt_rank; + const int64_t key_dim = head_k_dim * n_k_heads; + const int64_t value_dim = head_v_dim * n_v_heads; + const int64_t conv_dim = key_dim * 2 + value_dim; + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0); + + if (!hparams.is_recurrent(i)) { + // Attention layers + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // Q/K normalization for attention layers + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + } else { + // Linear attention (gated delta net) specific tensors + // Create tensors with calculated dimensions + layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), { n_embd, key_dim * 2 + value_dim }, TENSOR_NOT_REQUIRED); + layer.wqkv_gate = create_tensor(tn(LLM_TENSOR_ATTN_GATE, "weight", i), { n_embd, value_dim }, TENSOR_NOT_REQUIRED); + layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0); + layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0); + layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A_NOSCAN, i), { hparams.ssm_dt_rank }, 0); + layer.ssm_beta = create_tensor(tn(LLM_TENSOR_SSM_BETA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_alpha = create_tensor(tn(LLM_TENSOR_SSM_ALPHA, "weight", i), { n_embd, n_v_heads }, 0); + layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0); + layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0); + } + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; case LLM_ARCH_MIMO2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7545,6 +7962,8 @@ void llama_model::print_info() const { arch == LLM_ARCH_PLAMO2 || arch == LLM_ARCH_GRANITE_HYBRID || arch == LLM_ARCH_QWEN3NEXT || + arch == LLM_ARCH_QWEN35 || + arch == LLM_ARCH_QWEN35MOE || arch == LLM_ARCH_NEMOTRON_H || arch == LLM_ARCH_NEMOTRON_H_MOE) { LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv); @@ -7576,7 +7995,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_GLM_DSA) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -7776,7 +8195,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -7801,7 +8219,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_ubatch */ cparams.n_ubatch, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, @@ -7818,7 +8236,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, @@ -8149,6 +8567,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: { llm = std::make_unique(*this, params); } break; @@ -8195,6 +8614,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_JAIS2: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_NEMOTRON: { llm = std::make_unique(*this, params); @@ -8313,7 +8736,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { case LLM_ARCH_LFM2: case LLM_ARCH_LFM2MOE: { - llm = std::make_unique(*this, params); + if (hparams.swa_type == LLAMA_SWA_TYPE_STANDARD) { + llm = std::make_unique>(*this, params); + } else { + llm = std::make_unique>(*this, params); + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -8347,6 +8774,14 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_QWEN35: + { + llm = std::make_unique(*this, params); + } break; + case LLM_ARCH_QWEN35MOE: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_MISTRAL3: { llm = std::make_unique(*this, params); @@ -8368,7 +8803,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } // add on pooling layer - llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + llm->build_pooling(cls, cls_b, cls_out, cls_out_b, cls_norm); // add backend sampling layers (if any) llm->build_sampling(); @@ -8377,7 +8812,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize - llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->build_dense_out(dense_2_out_layers, dense_2_out_layers_b, dense_3_out_layers); llm->res->set_outputs(); @@ -8546,6 +8981,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 @@ -8594,6 +9030,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_BAILINGMOE2: case LLM_ARCH_DOTS1: case LLM_ARCH_HUNYUAN_MOE: + case LLM_ARCH_JAIS2: case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: @@ -8615,6 +9052,8 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { return LLAMA_ROPE_TYPE_MROPE; case LLM_ARCH_QWEN3VL: case LLM_ARCH_QWEN3VLMOE: + case LLM_ARCH_QWEN35: + case LLM_ARCH_QWEN35MOE: return LLAMA_ROPE_TYPE_IMROPE; case LLM_ARCH_GLM4: diff --git a/src/llama-model.h b/src/llama-model.h index 7b580043b3..422ed45699 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -118,6 +118,7 @@ enum llm_type { LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, + LLM_TYPE_35B_A3B, // Qwen3.5 LLM_TYPE_48B_A3B, // Kimi Linear LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, @@ -129,6 +130,7 @@ enum llm_type { LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -322,6 +324,9 @@ struct llama_layer { // qwen3next struct ggml_tensor * ssm_beta_alpha = nullptr; + // qwen3.5 + struct ggml_tensor * ssm_alpha = nullptr; + // rwkv struct ggml_tensor * time_mix_w1 = nullptr; struct ggml_tensor * time_mix_w2 = nullptr; @@ -425,6 +430,13 @@ struct llama_layer { struct ggml_tensor * ssm_g_b = nullptr; struct ggml_tensor * ssm_o_norm = nullptr; + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; @@ -463,6 +475,7 @@ struct llama_model { struct ggml_tensor * cls_b = nullptr; struct ggml_tensor * cls_out = nullptr; struct ggml_tensor * cls_out_b = nullptr; + struct ggml_tensor * cls_norm = nullptr; struct ggml_tensor * conv1d = nullptr; struct ggml_tensor * conv1d_b = nullptr; @@ -479,8 +492,9 @@ struct llama_model { //Dense linear projections for SentenceTransformers models like embeddinggemma // For Sentence Transformers models structure see // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models - struct ggml_tensor * dense_2_out_layers = nullptr; - struct ggml_tensor * dense_3_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_2_out_layers_b = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; // gguf metadata std::unordered_map gguf_kv; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 6d6bdfa090..657df711ef 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -289,6 +289,15 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_JAIS2: + regex_exprs = { + // original regex from tokenizer.json + //"(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + + // adapted: same as llama3 but with cascading whitespace pattern + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s{512}(?!\\S)|\\s{256}(?!\\S)|\\s{128}(?!\\S)|\\s{64}(?!\\S)|\\s{32}(?!\\S)|\\s{16}(?!\\S)|\\s{8}(?!\\S)|\\s{4}(?!\\S)|\\s{1,2}(?!\\S)|\\s{1}", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DBRX: case LLAMA_VOCAB_PRE_TYPE_SMAUG: regex_exprs = { @@ -308,6 +317,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE: + case LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM: regex_exprs = { "\\p{N}{1,3}", "[一-龥぀-ゟ゠-ヿ]+", @@ -368,6 +378,13 @@ struct llm_tokenizer_bpe : llm_tokenizer { "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_QWEN35: + regex_exprs = { + // original regex from tokenizer.json + // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])|[^\\r\\n\\p{L}\\p{N}]?[\\p{L}\\p{M}]+|\\p{N}| ?[^\\s\\p{L}\\p{M}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_PORO: case LLAMA_VOCAB_PRE_TYPE_BLOOM: case LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH: @@ -415,6 +432,14 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))*((?=[\\p{L}])([^A-Z]))+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?((?=[\\p{L}])([^a-z]))+((?=[\\p{L}])([^A-Z]))*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_TINY_AYA: + regex_exprs = { + // original regex from tokenizer.json: "\\d{1,3}(?=(?:\\d{3})*\\b)" + "\\d{1,3}(?=(?:\\d{3})*\\b)", + // original regex from tokenizer.json: "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}{1,3}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_KIMI_K2: regex_exprs = { // K2 trigger pattern - this will activate the custom K2 handler in unicode.cpp @@ -1905,8 +1930,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "jina-v2-de" || tokenizer_pre == "a.x-4.0" || tokenizer_pre == "mellum" || - tokenizer_pre == "modern-bert" ) { + tokenizer_pre == "modern-bert") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT2; + } else if ( + tokenizer_pre == "jais-2") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JAIS2; } else if ( tokenizer_pre == "jina-v1-en" || tokenizer_pre == "jina-v2-code" || @@ -1926,6 +1954,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "kormo") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; clean_spaces = false; + } else if ( + tokenizer_pre == "qwen35") { + pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN35; + clean_spaces = false; } else if ( tokenizer_pre == "stablelm2") { pre_type = LLAMA_VOCAB_PRE_TYPE_STABLELM2; @@ -1994,10 +2026,14 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "megrez") { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( - tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "gpt-4o" || + tokenizer_pre == "llama4") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; + } else if ( + tokenizer_pre == "tiny_aya") { + pre_type = LLAMA_VOCAB_PRE_TYPE_TINY_AYA; + clean_spaces = false; } else if ( tokenizer_pre == "superbpe") { pre_type = LLAMA_VOCAB_PRE_TYPE_SUPERBPE; @@ -2028,6 +2064,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "hunyuan-dense") { pre_type = LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE; clean_spaces = false; + } else if ( + tokenizer_pre == "joyai-llm") { + pre_type = LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM; + clean_spaces = false; } else if ( tokenizer_pre == "kimi-k2") { pre_type = LLAMA_VOCAB_PRE_TYPE_KIMI_K2; diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 28c3a82b91..be5b08012d 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -54,6 +54,10 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, LLAMA_VOCAB_PRE_TYPE_EXAONE_MOE = 45, + LLAMA_VOCAB_PRE_TYPE_QWEN35 = 46, + LLAMA_VOCAB_PRE_TYPE_TINY_AYA = 47, + LLAMA_VOCAB_PRE_TYPE_JOYAI_LLM = 48, + LLAMA_VOCAB_PRE_TYPE_JAIS2 = 49, }; struct LLM_KV; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 987f449934..b2c1f16060 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -45,7 +45,8 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < effective_n_layers; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -188,7 +189,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == n_layer - 1 && inp_out_ids) { + if (il == effective_n_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp new file mode 100644 index 0000000000..99f1fdd953 --- /dev/null +++ b/src/models/delta-net-base.cpp @@ -0,0 +1,376 @@ +#include "models.h" + +#define CHUNK_SIZE 64 + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + const bool kda = (g->ne[0] == S_k && g->ne[1] == H_k); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 0, 2, 1, 3); // [g_0, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 0, 2, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = CHUNK_SIZE; + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, g->ne[0], CS, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, g_0, n_chunks, H_v * n_seqs] + // TODO: extend ggml_cumsum with axis parameter to avoid transpose + ggml_tensor * g_cs = ggml_cumsum(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, g))); + cb(g_cs, "g_cs", il); + + ggml_tensor * kb = nullptr; + ggml_tensor * kq = nullptr; + if (kda) { + const int64_t CHB = n_chunks * H_k * n_seqs; + + ggml_tensor * g_cs_i = ggml_reshape_4d(ctx0, g_cs, CS, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] + + // decay_mask [chunk_size,chunk_size,S_k,CHB] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched + decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, CS, CS, CHB); + + ggml_tensor * k_b_i = ggml_reshape_4d(ctx0, k_b, S_k, CS, 1, CHB); + ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, CS, CHB); + ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, CS, 1, CHB); + + ggml_tensor * decay_k_b_i = ggml_mul(ctx0, decay_mask, k_b_i); + ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); + + // decay_k_b_i [S,BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] + kb = ggml_mul_mat(ctx0, decay_k_b_i, k_j); + kq = ggml_mul_mat(ctx0, decay_q_i, k_j); + + kb = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kb, CS, CS, n_chunks, H_v * n_seqs))); + kq = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, kq, CS, CS, n_chunks, H_v * n_seqs))); + } else { + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + } + + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + cb(attn, "attn", il); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + cb(attn, "attn_pre_solve", il); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [1, CS, n_chunks, H_k * n_seqs] KDA: [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_exp)); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] + // [1, 1, n_chunks, H_v * n_seqs] KDA: [1, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, g_cs->ne[1], g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] KDA: [S_k, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp_t = ggml_transpose(ctx0, ggml_exp(ctx0, g_last)); + cb(g_last_exp_t, "g_last_exp_t", il); + + // [CS, 1, n_chunks, H_v * n_seqs] KDA: [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp_t = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_exp(ctx0, g_diff))); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp_t = get_slice_2d(ctx0, g_last_exp_t, chunk); + + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp_t); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); + } + + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); + cb(s, "output_state", il); + + return {o, s}; +} + +std::pair llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == 1 || g->ne[0] == S_v); + GGML_ASSERT( g->ne[1] == H_v && g->ne[2] == n_tokens && g->ne[3] == n_seqs); + GGML_ASSERT(b->ne[0] == 1 && b->ne[1] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + // GDA: [1, 1, H_v, n_seqs] + // KDA: [1, S_k, H_v, n_seqs] + g = ggml_reshape_4d(ctx0, g, 1, g->ne[0], H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s_t = ggml_add(ctx0, s_t, kd); + + cb(s_t, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} diff --git a/src/models/falcon-h1.cpp b/src/models/falcon-h1.cpp index b641a09407..785a7e5e66 100644 --- a/src/models/falcon-h1.cpp +++ b/src/models/falcon-h1.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/glm4.cpp b/src/models/glm4.cpp index 204aa3932a..bcd837b30d 100644 --- a/src/models/glm4.cpp +++ b/src/models/glm4.cpp @@ -29,7 +29,10 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + // Only process up to last layer (skip final NextN layer) + // Final layer tensors are loaded but not processed in forward pass + const int n_transformer_layers = n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < n_transformer_layers; ++il) { ggml_tensor * inpSA = inpL; // Pre-attention norm @@ -100,7 +103,7 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params model.layers[il].wo, NULL, Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); } - if (il == n_layer - 1 && inp_out_ids) { + if (il == n_transformer_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } @@ -130,9 +133,13 @@ llm_build_glm4::llm_build_glm4(const llama_model & model, const llm_graph_params cur = build_norm(cur, model.layers[il].ffn_post_norm, NULL, LLM_NORM_RMS, il); cb(cur, "post_mlp_norm", il); } - // Add residual connection after post-MLP norm - inpL = ggml_add(ctx0, cur, ffn_inp); - cb(inpL, "l_out", il); + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; } // Final norm cur = build_norm(inpL, model.output_norm, NULL, LLM_NORM_RMS, -1); diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index f6ca4c17a2..726ecdcca7 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -2,7 +2,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/jais2.cpp b/src/models/jais2.cpp new file mode 100644 index 0000000000..a69fcaa3bb --- /dev/null +++ b/src/models/jais2.cpp @@ -0,0 +1,123 @@ +#include "models.h" + +// JAIS-2 model graph builder +// Uses: LayerNorm (not RMSNorm), relu2 activation, separate Q/K/V, RoPE embeddings +llm_build_jais2::llm_build_jais2(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + // KV input for attention + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + // Pre-attention LayerNorm + cur = build_norm(inpL, + model.layers[il].attn_norm, + model.layers[il].attn_norm_b, + LLM_NORM, il); + cb(cur, "attn_norm", il); + + // Self-attention with separate Q, K, V projections + { + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + Qcur = ggml_add(ctx0, Qcur, model.layers[il].bq); + cb(Qcur, "Qcur_bias", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + Kcur = ggml_add(ctx0, Kcur, model.layers[il].bk); + cb(Kcur, "Kcur_bias", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + Vcur = ggml_add(ctx0, Vcur, model.layers[il].bv); + cb(Vcur, "Vcur_bias", il); + + // Reshape for attention + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply RoPE + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_rope", il); + cb(Kcur, "Kcur_rope", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + // Residual connection + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpL); + cb(ffn_inp, "ffn_inp", il); + + // Pre-FFN LayerNorm + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, + model.layers[il].ffn_norm_b, + LLM_NORM, il); + cb(cur, "ffn_norm", il); + + // FFN with relu2 activation (ReLU squared) - no gate projection + // up -> relu2 -> down + cur = build_ffn(cur, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, + NULL, NULL, NULL, // no gate + model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, + NULL, + LLM_FFN_RELU_SQR, LLM_FFN_SEQ, il); + cb(cur, "ffn_out", il); + + // Residual connection + inpL = ggml_add(ctx0, cur, ffn_inp); + inpL = build_cvec(inpL, il); + cb(inpL, "l_out", il); + } + + // Final LayerNorm + cur = build_norm(inpL, + model.output_norm, + model.output_norm_b, + LLM_NORM, -1); + cb(cur, "result_norm", -1); + + res->t_embd = cur; + + // Output projection + cur = build_lora_mm(model.output, cur); + cb(cur, "result_output", -1); + + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/jamba.cpp b/src/models/jamba.cpp index a0187772cc..ceab581740 100644 --- a/src/models/jamba.cpp +++ b/src/models/jamba.cpp @@ -1,6 +1,6 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 0f037d1a39..4d6bb83c14 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -1,7 +1,7 @@ #include "models.h" #include "ggml.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" // Causal Conv1d function for Q,K,V // When qkv is 0, it is Q, 1 is K, 2 is V @@ -41,8 +41,11 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_x, - ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, - (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); + ggml_view_3d(ctx0, conv_states_all, + d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps + n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) + (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] @@ -62,7 +65,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -81,17 +84,6 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll // Output ids for selecting which tokens to output ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * chunked_causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * chunked_identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * chunked_diag_mask = ggml_add(ctx0, chunked_causal_mask, chunked_identity); - - ggml_build_forward_expand(gf, chunked_causal_mask); - ggml_build_forward_expand(gf, chunked_identity); - ggml_build_forward_expand(gf, chunked_diag_mask); - // Kimi dimension constants const int64_t n_head = hparams.n_head(); const int64_t head_dim = hparams.n_embd_head_kda; @@ -157,27 +149,35 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll g1 = ggml_mul(ctx0, g1, A); cb(g1, "kda_g1", il); + g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); + // Compute beta (mixing coefficient) ggml_tensor * beta = ggml_mul_mat(ctx0, layer.ssm_beta, cur); - beta = ggml_reshape_4d(ctx0, beta, n_head, 1, n_seq_tokens, n_seqs); + beta = ggml_reshape_4d(ctx0, beta, 1, n_head, n_seq_tokens, n_seqs); cb(beta, "kda_beta", il); + beta = ggml_sigmoid(ctx0, beta); + // Reshape for KDA recurrence // {n_embd, n_tokens} -> {n_embd, n_seq_tokens, n_seqs} cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - g1 = ggml_reshape_4d(ctx0, g1, head_dim, n_head, n_seq_tokens, n_seqs); - // Get SSM state and compute KDA recurrence using ggml_kda_scan ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); ggml_tensor * state = build_rs(inp_rs, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_dim, head_dim, n_head, n_seqs); - // Choose between build_kda_chunking and build_kda_recurrent based on n_tokens - std::pair attn_out = n_seq_tokens == 1 ? - build_kda_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : - build_kda_chunking(Qcur, Kcur, Vcur, g1, beta, state, chunked_causal_mask, chunked_identity, chunked_diag_mask, il); - ggml_tensor * output = attn_out.first; + const float eps_norm = hparams.f_norm_rms_eps; + + Qcur = ggml_l2_norm(ctx0, Qcur, eps_norm); + Kcur = ggml_l2_norm(ctx0, Kcur, eps_norm); + + // Choose between build_delta_net_chunking and build_delta_net_recurrent based on n_tokens + std::pair attn_out = n_seq_tokens == 1 ? + build_delta_net_autoregressive(Qcur, Kcur, Vcur, g1, beta, state, il) : + build_delta_net_chunking(Qcur, Kcur, Vcur, g1, beta, state, il); + + ggml_tensor * output = ggml_cont(ctx0, attn_out.first); ggml_tensor * new_state = attn_out.second; cb(output, "attn_output", il); cb(new_state, "new_state", il); @@ -388,385 +388,3 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll ggml_build_forward_expand(gf, cur); } - -/* - This is a ggml implementation of the naive_chunk_kda function of - https://github.com/fla-org/flash-linear-attention/blob/main/fla/ops/kda/naive.py -*/ -std::pair llm_build_kimi_linear::build_kda_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - GGML_ASSERT(ggml_is_contiguous(state)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_v && gk->ne[1] == H_v && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - // TODO: can this ever be false? - const bool use_qk_l2norm = true; - - if (use_qk_l2norm) { - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - } - - const float scale = 1.0f / sqrtf(S_v); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_k, n_tokens, H_k, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(gk, "gk_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - gk = ggml_pad(ctx0, gk, 0, pad, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(gk, "gk_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - const int64_t HB = H_k * n_seqs; - - q = ggml_cont_4d(ctx0, q, S_k, chunk_size, n_chunks, HB); - k = ggml_cont_4d(ctx0, k, S_k, chunk_size, n_chunks, HB); - k_beta = ggml_cont_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, HB); - v = ggml_cont_4d(ctx0, v, S_v, chunk_size, n_chunks, HB); - v_beta = ggml_cont_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, HB); - - gk = ggml_cont_4d(ctx0, gk, S_k, chunk_size, n_chunks, HB); - beta = ggml_cont_4d(ctx0, beta, 1, chunk_size, n_chunks, HB); - - // switch for cumsum - gk = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk, 1, 0, 2, 3), chunk_size, S_k, n_chunks, HB); - cb(gk, "gk", il); - ggml_tensor * gk_cumsum = ggml_cumsum(ctx0, gk); - cb(gk_cumsum, "gk_cumsum", il); - -/* - Compute Akk and Aqk loop together - Akk loop: - for i in range(BT): - k_i = k[..., i, :] # k_i [B,H,NT,S] - g_i = g[..., i:i+1, :] # g_i [B,H,NT,1,S] - A[..., i] = torch.einsum('... c d, ... d -> ... c', k * (g - g_i).exp(), k_i) - Aqk loop: - for j in range(BT): - k_j = k[:, :, i, j] - g_j = g[:, :, i, j:j+1, :] - A[..., j] = torch.einsum('... c d, ... d -> ... c', q_i * (g_i - g_j).exp(), k_j) -*/ - const int64_t CHB = n_chunks * H_k * n_seqs; - ggml_tensor * gkcs_i = ggml_reshape_4d(ctx0, gk_cumsum, chunk_size, 1, S_k, CHB); // [chunk_size, 1, S_k, CHB] - ggml_tensor * gkcs_j = ggml_reshape_4d(ctx0, gkcs_i, 1, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] - - ggml_tensor * gkcs_j_bc = ggml_repeat_4d(ctx0, gkcs_j, chunk_size, chunk_size, S_k, CHB); // [1, chunk_size, S_k, CHB] -> [chunk_size, chunk_size, S_k, CHB] - // decay_mask [chunk_size,chunk_size,S_k,CHB] - ggml_tensor * decay_mask = ggml_sub(ctx0, gkcs_j_bc, gkcs_i); - cb(decay_mask, "decay_mask", il); - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - cb(decay_mask, "decay_masked", il); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - // decay_mask [S_k,BT_j,BT_i,CHB] *Note* second and third chunk_sizes are switched - decay_mask = ggml_cont_4d(ctx0, ggml_permute(ctx0, decay_mask, 2, 1, 0, 3), S_k, chunk_size, chunk_size, CHB); - - ggml_tensor * k_i = ggml_reshape_4d(ctx0, k, S_k, chunk_size, 1, CHB); - ggml_tensor * k_j = ggml_reshape_4d(ctx0, k, S_k, 1, chunk_size, CHB); - ggml_tensor * q_i = ggml_reshape_4d(ctx0, q, S_k, chunk_size, 1, CHB); - - ggml_tensor * decay_k_i = ggml_mul(ctx0, decay_mask, k_i); - ggml_tensor * decay_q_i = ggml_mul(ctx0, decay_mask, q_i); - - // decay_k_i [S.BT,BT,CHB] @ k_j [S,1,BT,CHB] = Akk [BT,1,BT,CHB] - ggml_tensor * Akk = ggml_mul_mat(ctx0, decay_k_i, k_j); - ggml_tensor * Aqk = ggml_mul_mat(ctx0, decay_q_i, k_j); - Akk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Akk, chunk_size, chunk_size, n_chunks, HB))); - Aqk = ggml_cont(ctx0, ggml_transpose(ctx0, ggml_reshape_4d(ctx0, Aqk, chunk_size, chunk_size, n_chunks, HB))); - cb(Akk, "Akk", il); - cb(Aqk, "Aqk", il); - - Akk = ggml_mul(ctx0, Akk, beta); - Akk = ggml_neg(ctx0, ggml_mul(ctx0, Akk, causal_mask)); - cb(Akk, "attn_pre_solve", il); - - Aqk = ggml_mul(ctx0, Aqk, diag_mask); - Aqk = ggml_scale(ctx0, Aqk, scale); // scale q - cb(Aqk, "Aqk_masked", il); - - // for i in range(1, chunk_size): - // row = attn[..., i, :i].clone() - // sub = attn[..., :i, :i].clone() - // attn[..., i, :i] = row + (row.unsqueeze(-1) * sub).sum(-2) - // attn = attn + torch.eye(chunk_size, dtype=attn.dtype, device=attn.device) - // - // We reduce this to a linear triangular solve: AX = B, where B = attn, A = I - tril(A) - ggml_tensor * attn_lower = ggml_mul(ctx0, Akk, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, Akk, true, true, false); - Akk = ggml_mul(ctx0, lin_solve, causal_mask); - Akk = ggml_add(ctx0, Akk, identity); - - cb(Akk, "attn_solved", il); - - // switch back for downstream - gk_cumsum = ggml_cont_4d(ctx0, ggml_permute(ctx0, gk_cumsum, 1, 0, 2, 3), S_k, chunk_size, n_chunks, HB); - ggml_tensor * gkexp = ggml_exp(ctx0, gk_cumsum); - cb(gk_cumsum, "gk_cumsum", il); - - // u = (A*beta[..., None, :]) @ v aka U_[t] - ggml_tensor * vb = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), Akk); - - ggml_tensor * kbeta_gkexp = ggml_mul(ctx0, k_beta, gkexp); - cb(kbeta_gkexp, "kbeta_gkexp", il); - - ggml_tensor * k_cumdecay = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gkexp)), Akk); - cb(k_cumdecay, "k_cumdecay", il); - - ggml_tensor * core_attn_out = nullptr; - ggml_tensor * new_state = ggml_dup(ctx0, state); - - cb(new_state, "new_state", il); - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { -// extract one chunk worth of data - auto chunkify = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, t->ne[0], chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - auto chunkify_A = [=](ggml_tensor * t) { - return ggml_cont(ctx0, ggml_view_4d(ctx0, t, chunk_size, chunk_size, 1, t->ne[3], - t->nb[1], t->nb[2], t->nb[3], t->nb[2] * chunk)); - }; - - -// k [S,BT,NT,H*B] => k_chunk [S,BT,1,H*B] - ggml_tensor * k_chunk = chunkify(k); - ggml_tensor * q_chunk = chunkify(q); - ggml_tensor * vb_chunk = chunkify(vb); - -// gk_cumsum [S,BT,NT,H*B] => gk_cs_chunk [S,BT,1,H*B] - ggml_tensor * gk_cs_chunk = chunkify(gk_cumsum); - ggml_tensor * k_cumdecay_chunk = chunkify(k_cumdecay); - ggml_tensor * gkexp_chunk = ggml_exp(ctx0, gk_cs_chunk); - ggml_tensor * Aqk_chunk = chunkify_A(Aqk); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // new_state [S,S,1,H*B] k_cumdecay_chunk [S,BT,1,H*B] - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state or W_[t] @ S_[t] - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - - // v_new = v_i - v_prime or U_[t] - W_[t]*S_[t] - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, vb_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - - // q_chunk [S,BT,1,H*B] gkexp_chunk [S,BT,1,H*B] - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - // or Gamma_[t]*Q_]t] @ S - ggml_tensor * q_gk_exp = ggml_mul(ctx0, q_chunk, gkexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_gk_exp); - attn_inter = ggml_scale(ctx0, attn_inter, scale); // scale q - - // v_new_t [S,BT,1,H*B] Aqk [BT,BT,1,H*B] - // core_attn_out[:, :, i] = attn_inter + attn @ v_new or A' @ (U_[t] - W_[t]*S_[t]) - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, Aqk_chunk); - - // o[:, :, i] = (q_i * g_i.exp()) @ S + A @ v_i - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - - core_attn_out = core_attn_out == nullptr ? core_attn_out_chunk : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 1); - - ggml_tensor * gk_cum_last = - ggml_cont(ctx0, ggml_view_4d(ctx0, gk_cs_chunk, gk_cs_chunk->ne[0], 1, gk_cs_chunk->ne[2], gk_cs_chunk->ne[3], - gk_cs_chunk->nb[1], gk_cs_chunk->nb[2], gk_cs_chunk->nb[3], - gk_cs_chunk->nb[1] * (gk_cs_chunk->ne[1] - 1))); - - ggml_tensor * gkexp_last = ggml_exp(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, gk_cum_last))); - - ggml_tensor * gk_diff = ggml_neg(ctx0, ggml_sub(ctx0, gk_cs_chunk, gk_cum_last)); - - ggml_tensor * gk_diff_exp = ggml_exp(ctx0, gk_diff); - - ggml_tensor * key_gkdiff = ggml_mul(ctx0, k_chunk, gk_diff_exp); - - // rearrange((g_i[:,:,-1:] - g_i).exp()*k_i, 'b h c k -> b h k c') @ (U_[t] - W_[t] @ S) - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, ggml_cont(ctx0, ggml_transpose(ctx0, key_gkdiff))); - - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gkexp_last, gkexp_last->ne[0], gkexp_last->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - core_attn_out = ggml_cont_4d(ctx0, core_attn_out, S_v, chunk_size * n_chunks, H_v, n_seqs); - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - cb(new_state, "output_state", il); - - return {output_tokens, new_state}; -} - -std::pair llm_build_kimi_linear::build_kda_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * gk, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - GGML_ASSERT(ggml_is_contiguous(v)); - GGML_ASSERT(ggml_is_contiguous(gk)); - - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(gk->ne[0] == S_k && gk->ne[1] == H_k && gk->ne[2] == n_tokens && gk->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_k && state->ne[2] == H_v && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(gk, "gk_in", il); - -// g [H,1,B,1] g_t [1,H,B,1] => [1,1,H,B] -// gk [S,H,1,B] => [S,1,H,B] gk_t [1,S,H,B] -// beta [H,1,1,B] beta_t [1,H,1,B] => [1,1,H,B] - gk = ggml_reshape_4d(ctx0, gk, S_k, 1, H_k, n_seqs); - ggml_tensor * gk_t = ggml_cont(ctx0, ggml_transpose(ctx0, gk)); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to gk_t - gk_t = ggml_exp(ctx0, gk_t); - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * gk_t - // S = S * g_i[..., None].exp() - state = ggml_mul(ctx0, state, gk_t); - - ggml_tensor * state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - -// state [S,S,H,B] k [S,1,H,B] k_state [S_v,1,H,B] - k = ggml_reshape_4d(ctx0, k, S_k, 1, H_k, n_seqs); - ggml_tensor * k_state = ggml_mul_mat(ctx0, state_t, k); - - // v_i - (k_i[..., None] * S).sum(-2) - v = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - ggml_tensor * v_diff = ggml_sub(ctx0, v, k_state); - - // b_i[..., None] * k_i - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta_t); - - // S = S + torch.einsum('b h k, b h v -> b h k v', b_i[..., None] * k_i, v_i - (k_i[..., None] * S).sum(-2)) - // v_diff_t [1,S_v,H,B] k_beta_t [1,S_k,H,B] state [S_v,S_k,H,B] - state = ggml_add(ctx0, state, ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_diff)), ggml_cont(ctx0, ggml_transpose(ctx0, k_beta)))); - - q = ggml_reshape_4d(ctx0, q, S_k, 1, H_k, n_seqs); - state_t = ggml_cont(ctx0, ggml_transpose(ctx0, state)); - ggml_tensor * core_attn_out = ggml_mul_mat(ctx0, state_t, q); - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - diff --git a/src/models/lfm2.cpp b/src/models/lfm2.cpp index 7f805d7879..cf01ad6255 100644 --- a/src/models/lfm2.cpp +++ b/src/models/lfm2.cpp @@ -1,18 +1,149 @@ #include "models.h" +#include "../llama-memory-hybrid-iswa.h" #include "../llama-memory-hybrid.h" +template +llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : + llm_graph_context(params) { + using inp_hybrid_type = std::conditional_t; + using inp_attn_type = std::conditional_t; + using mem_hybrid_ctx = std::conditional_t; -llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context(params), - model(model) { + // lambda helpers for readability + auto build_dense_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + GGML_ASSERT(!model.layers[il].ffn_up_b); + GGML_ASSERT(!model.layers[il].ffn_gate_b); + GGML_ASSERT(!model.layers[il].ffn_down_b); + return build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + }; + auto build_moe_feed_forward = [&model, this](ggml_tensor * cur, int il) -> ggml_tensor * { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, + static_cast(hparams.expert_gating_func), il); + }; + auto build_attn_block = [&model, this](ggml_tensor * cur, + ggml_tensor * inp_pos, + inp_attn_type * inp_attn, + int il) -> ggml_tensor * { + GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); + const auto n_embd_head = hparams.n_embd_head_v; + const auto n_head_kv = hparams.n_head_kv(il); + + auto * q = build_lora_mm(model.layers[il].wq, cur); + cb(q, "model.layers.{}.self_attn.q_proj", il); + auto * k = build_lora_mm(model.layers[il].wk, cur); + cb(k, "model.layers.{}.self_attn.k_proj", il); + auto * v = build_lora_mm(model.layers[il].wv, cur); + cb(v, "model.layers.{}.self_attn.v_proj", il); + + q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); + k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); + v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); + + // qk norm + q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(q, "model.layers.{}.self_attn.q_layernorm", il); + k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(k, "model.layers.{}.self_attn.k_layernorm", il); + + // RoPE + q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, + attn_factor, beta_fast, beta_slow); + + cur = build_attn(inp_attn, + model.layers[il].wo, NULL, + q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); + + cb(cur, "model.layers.{}.self_attn.out_proj", il); + + return cur; + }; + auto build_shortconv_block = [&model, this](ggml_tensor * cur, + llm_graph_input_rs * inp_recr, + int il) -> ggml_tensor * { + const auto * mctx_cur = static_cast(mctx)->get_recr(); + const uint32_t kv_head = mctx_cur->get_head(); + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + const int64_t n_seqs = ubatch.n_seqs; + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + GGML_ASSERT(hparams.n_shortconv_l_cache > 1); + const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; + + // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} + cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); + + auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); + cb(bcx, "model.layers.{}.conv.in_proj", il); + + constexpr auto n_chunks = 3; + GGML_ASSERT(bcx->ne[0] % n_chunks == 0); + const auto chunk_size = bcx->ne[0] / n_chunks; + auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 0 * chunk_size * ggml_element_size(bcx)); + auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 1 * chunk_size * ggml_element_size(bcx)); + auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], + 2 * chunk_size * ggml_element_size(bcx)); + + auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); + + // read conv state + auto * conv_state = mctx_cur->get_r_l(il); + auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); + auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); + + bx = ggml_concat(ctx0, conv, bx, 0); + GGML_ASSERT(bx->ne[0] > conv->ne[0]); + + // last d_conv columns is a new conv state + auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], + (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); + GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); + + // write new conv conv state + ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, + ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), + kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); + + auto * conv_kernel = model.layers[il].shortconv.conv; + auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); + cb(conv_out, "model.layers.{}.conv.conv", il); + + auto * y = ggml_mul(ctx0, c, conv_out); + y = build_lora_mm(model.layers[il].shortconv.out_proj, y); + cb(y, "model.layers.{}.conv.out_proj", il); + // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} + y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); + + return y; + }; + + // actual graph construction starts here ggml_tensor * cur = build_inp_embd(model.tok_embd); cb(cur, "model.embed_tokens", -1); ggml_build_forward_expand(gf, cur); + inp_hybrid_type * inp_hybrid = nullptr; + if constexpr (iswa) { + inp_hybrid = build_inp_mem_hybrid_iswa(); + } else { + inp_hybrid = build_inp_mem_hybrid(); + } + ggml_tensor * inp_pos = build_inp_pos(); - auto * inp_hybrid = build_inp_mem_hybrid(); ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { @@ -54,122 +185,6 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params ggml_build_forward_expand(gf, cur); } -ggml_tensor * llm_build_lfm2::build_moe_feed_forward(ggml_tensor * cur, int il) const { - return build_moe_ffn(cur, - model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, - model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, - model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, true, false, 0.0, - static_cast(hparams.expert_gating_func), il); -} - -ggml_tensor * llm_build_lfm2::build_dense_feed_forward(ggml_tensor * cur, int il) const { - GGML_ASSERT(!model.layers[il].ffn_up_b); - GGML_ASSERT(!model.layers[il].ffn_gate_b); - GGML_ASSERT(!model.layers[il].ffn_down_b); - return build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, - model.layers[il].ffn_gate, NULL, NULL, - model.layers[il].ffn_down, NULL, NULL, - NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); -} - -ggml_tensor * llm_build_lfm2::build_attn_block(ggml_tensor * cur, - ggml_tensor * inp_pos, - llm_graph_input_attn_kv * inp_attn, - int il) const { - GGML_ASSERT(hparams.n_embd_v_gqa(il) == hparams.n_embd_k_gqa(il)); - const auto n_embd_head = hparams.n_embd_head_v; - const auto n_head_kv = hparams.n_head_kv(il); - - auto * q = build_lora_mm(model.layers[il].wq, cur); - cb(q, "model.layers.{}.self_attn.q_proj", il); - auto * k = build_lora_mm(model.layers[il].wk, cur); - cb(k, "model.layers.{}.self_attn.k_proj", il); - auto * v = build_lora_mm(model.layers[il].wv, cur); - cb(v, "model.layers.{}.self_attn.v_proj", il); - - q = ggml_reshape_3d(ctx0, q, n_embd_head, n_head, n_tokens); - k = ggml_reshape_3d(ctx0, k, n_embd_head, n_head_kv, n_tokens); - v = ggml_reshape_3d(ctx0, v, n_embd_head, n_head_kv, n_tokens); - - // qk norm - q = build_norm(q, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); - cb(q, "model.layers.{}.self_attn.q_layernorm", il); - k = build_norm(k, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); - cb(k, "model.layers.{}.self_attn.k_layernorm", il); - - // RoPE - q = ggml_rope_ext(ctx0, q, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - k = ggml_rope_ext(ctx0, k, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, ext_factor, - attn_factor, beta_fast, beta_slow); - - cur = build_attn(inp_attn, - model.layers[il].wo, NULL, - q, k, v, nullptr, nullptr, nullptr, 1.0f / sqrtf(float(n_embd_head)), il); - - cb(cur, "model.layers.{}.self_attn.out_proj", il); - - return cur; -} - -ggml_tensor * llm_build_lfm2::build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il) { - const auto * mctx_cur = static_cast(mctx)->get_recr(); - const uint32_t kv_head = mctx_cur->get_head(); - const int64_t n_seq_tokens = ubatch.n_seq_tokens; - const int64_t n_seqs = ubatch.n_seqs; - GGML_ASSERT(n_seqs != 0); - GGML_ASSERT(ubatch.equal_seqs()); - GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); - - GGML_ASSERT(hparams.n_shortconv_l_cache > 1); - const uint32_t d_conv = hparams.n_shortconv_l_cache - 1; - - // {n_embd, n_tokens} => {n_embd, n_seq_tokens, n_seqs} - cur = ggml_reshape_3d(ctx0, cur, cur->ne[0], n_seq_tokens, n_seqs); - - auto * bcx = build_lora_mm(model.layers[il].shortconv.in_proj, cur); - cb(bcx, "model.layers.{}.conv.in_proj", il); - - constexpr auto n_chunks = 3; - GGML_ASSERT(bcx->ne[0] % n_chunks == 0); - const auto chunk_size = bcx->ne[0] / n_chunks; - auto * b = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 0 * chunk_size * ggml_element_size(bcx)); - auto * c = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 1 * chunk_size * ggml_element_size(bcx)); - auto * x = ggml_view_3d(ctx0, bcx, chunk_size, bcx->ne[1], bcx->ne[2], bcx->nb[1], bcx->nb[2], - 2 * chunk_size * ggml_element_size(bcx)); - - auto * bx = ggml_transpose(ctx0, ggml_mul(ctx0, b, x)); - - // read conv state - auto * conv_state = mctx_cur->get_r_l(il); - auto * conv_rs = build_rs(inp_recr, conv_state, hparams.n_embd_r(), n_seqs); - auto * conv = ggml_reshape_3d(ctx0, conv_rs, d_conv, hparams.n_embd, n_seqs); - - bx = ggml_concat(ctx0, conv, bx, 0); - GGML_ASSERT(bx->ne[0] > conv->ne[0]); - - // last d_conv columns is a new conv state - auto * new_conv = ggml_view_3d(ctx0, bx, conv->ne[0], bx->ne[1], bx->ne[2], bx->nb[1], bx->nb[2], - (bx->ne[0] - conv->ne[0]) * ggml_element_size(bx)); - GGML_ASSERT(ggml_are_same_shape(conv, new_conv)); - - // write new conv conv state - ggml_build_forward_expand(gf, ggml_cpy(ctx0, new_conv, - ggml_view_1d(ctx0, conv_state, ggml_nelements(new_conv), - kv_head * d_conv * n_embd * ggml_element_size(new_conv)))); - - auto * conv_kernel = model.layers[il].shortconv.conv; - auto * conv_out = ggml_ssm_conv(ctx0, bx, conv_kernel); - cb(conv_out, "model.layers.{}.conv.conv", il); - - auto * y = ggml_mul(ctx0, c, conv_out); - y = build_lora_mm(model.layers[il].shortconv.out_proj, y); - cb(y, "model.layers.{}.conv.out_proj", il); - // {n_embd, n_seq_tokens, n_seqs} => {n_embd, n_tokens} - y = ggml_reshape_2d(ctx0, y, y->ne[0], n_seq_tokens * n_seqs); - - return y; -} +// Explicit template instantiations +template struct llm_build_lfm2; +template struct llm_build_lfm2; diff --git a/src/models/graph-context-mamba.cpp b/src/models/mamba-base.cpp similarity index 97% rename from src/models/graph-context-mamba.cpp rename to src/models/mamba-base.cpp index b9a363b32b..aaac9487df 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/src/models/mamba.cpp b/src/models/mamba.cpp index 46819613c2..55fd2e055c 100644 --- a/src/models/mamba.cpp +++ b/src/models/mamba.cpp @@ -1,7 +1,6 @@ #include "models.h" - -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/models.h b/src/models/models.h index b01971faf5..b6bed17e93 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,23 +1,51 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" +// note: almost all graphs require atleast sqrtf, so include cmath globally #include -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +// +// base classes +// - virtual ~llm_graph_context_mamba() = default; +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); + + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +// +// models +// + struct llm_build_afmoe : public llm_graph_context { llm_build_afmoe(const llama_model & model, const llm_graph_params & params); }; @@ -175,7 +207,7 @@ struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { +struct llm_build_falcon_h1 : public llm_build_mamba_base { llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); }; @@ -262,7 +294,7 @@ private: const int il); }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite_hybrid : public llm_build_mamba_base { llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -293,11 +325,15 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_jamba : public llm_graph_context_mamba { +struct llm_build_jais2 : public llm_graph_context { + llm_build_jais2(const llama_model & model, const llm_graph_params & params); +}; + +struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_kimi_linear : public llm_graph_context_mamba { +struct llm_build_kimi_linear : public llm_build_delta_net_base { llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -324,15 +360,9 @@ struct llm_build_kimi_linear : public llm_graph_context_mamba { const llama_model & model; }; +template struct llm_build_lfm2 : public llm_graph_context { - const llama_model & model; - llm_build_lfm2(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, int il) const; - ggml_tensor * build_attn_block(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, int il) const; - ggml_tensor * build_shortconv_block(ggml_tensor * cur, llm_graph_input_rs * inp_recr, int il); - }; struct llm_build_llada : public llm_graph_context { @@ -356,7 +386,7 @@ struct llm_build_maincoder : public llm_graph_context { llm_build_maincoder(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_mamba : public llm_graph_context_mamba { +struct llm_build_mamba : public llm_build_mamba_base { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -388,11 +418,11 @@ struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { +struct llm_build_nemotron_h : public llm_build_mamba_base { llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + const llama_model & model, int64_t n_embd_head, int il); }; struct llm_build_neo_bert : public llm_graph_context { @@ -437,7 +467,7 @@ struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_plamo2 : public llm_graph_context_mamba { +struct llm_build_plamo2 : public llm_build_mamba_base { llm_build_plamo2(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); @@ -485,7 +515,8 @@ struct llm_build_qwen3vl : public llm_graph_context { struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { + +struct llm_build_qwen3next : public llm_build_delta_net_base { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -497,37 +528,78 @@ private: ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, int il); - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); + const llama_model & model; +}; + +struct llm_build_qwen35 : public llm_build_delta_net_base { + llm_build_qwen35(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); + + ggml_tensor * build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer); + + // returns pair of qkv, z + std::pair build_qkvz( + ggml_tensor * input, + int il); + + const llama_model & model; +}; + +// TODO: derive llm_build_delta_net_base instead +struct llm_build_qwen35moe : public llm_build_delta_net_base { + llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); +private: + ggml_tensor * build_layer_attn( + llm_graph_input_attn_kv * inp_attn, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il); + + ggml_tensor * build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il); + + ggml_tensor * build_layer_ffn( + ggml_tensor * cur, + int il); ggml_tensor * build_norm_gated( ggml_tensor * input, diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index bb12ed819f..32066c712b 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -104,13 +104,6 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const ll LLM_NORM, -1); cb(cur, "final_norm_out", -1); - if (hparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { - // extracting cls token - cur = ggml_view_1d(ctx0, cur, hparams.n_embd, 0); - cb(cur, "cls_pooled_embd", -1); - } - - cb(cur, "res_embd", -1); res->t_embd = cur; ggml_build_forward_expand(gf, cur); } diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 079c730ac2..d61d62a8c9 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { + int64_t n_embd_head, + int il) { // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index 31115a08f9..3af236843b 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -1,7 +1,9 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp new file mode 100644 index 0000000000..56eefd7de2 --- /dev/null +++ b/src/models/qwen35.cpp @@ -0,0 +1,385 @@ +#include "models.h" + +#include "llama-memory-recurrent.h" + +llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // Dense FFN layer - without residual connection + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_ffn", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +std::pair llm_build_qwen35::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply MRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + + cb(q_conv, "q_conv", il); + cb(k_conv, "k_conv", il); + cb(v_conv, "v_conv", il); + + const float eps_norm = hparams.f_norm_rms_eps; + + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + // TODO: try to avoid these explicit repeats by utilizing op broadcast + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + std::pair attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + + return cur; +} + +ggml_tensor * llm_build_qwen35::build_layer_ffn(ggml_tensor * cur, const int il) { + // Qwen3.5 does not use MoE FFN + GGML_ASSERT(model.layers[il].ffn_gate_inp == nullptr); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + return cur; +} diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp new file mode 100644 index 0000000000..c7295e3364 --- /dev/null +++ b/src/models/qwen35moe.cpp @@ -0,0 +1,418 @@ +#include "models.h" + +#include "llama-memory-recurrent.h" + +llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : + llm_build_delta_net_base(params), model(model) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + int sections[4]; + std::copy(std::begin(hparams.rope_sections), std::begin(hparams.rope_sections) + 4, sections); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + cb(inpL, "model.input_embed", -1); + + auto * inp = build_inp_mem_hybrid(); + + ggml_tensor * inp_pos = build_inp_pos(); + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // Determine layer type and build appropriate attention mechanism + if (hparams.is_recurrent(il)) { + // Linear attention layer (gated delta net) + cur = build_layer_attn_linear(inp->get_recr(), cur, il); + } else { + // Full attention layer + cur = build_layer_attn(inp->get_attn(), cur, inp_pos, sections, il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + // Residual connection + cur = ggml_add(ctx0, cur, inpSA); + cb(cur, "attn_residual", il); + + // Save the tensor before post-attention norm for residual connection + ggml_tensor * ffn_residual = cur; + + // Post-attention norm + ggml_tensor * attn_post_norm = build_norm(cur, model.layers[il].attn_post_norm, nullptr, LLM_NORM_RMS, il); + cb(attn_post_norm, "attn_post_norm", il); + + // MOE FFN layer + cur = build_layer_ffn(attn_post_norm, il); + cb(cur, "ffn_out", il); + + // Residual connection for FFN - add to the tensor from before post_attention_layernorm + cur = ggml_add(ctx0, cur, ffn_residual); + cb(cur, "post_moe", il); + + // Input for next layer + inpL = cur; + } + cur = inpL; + + // Final norm + cur = build_norm(cur, model.output_norm, nullptr, LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // LM head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} + +std::pair llm_build_qwen35moe::build_qkvz( + ggml_tensor * input, + int il) { + const int64_t n_seqs = ubatch.n_seqs; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + ggml_tensor * qkv_mixed = build_lora_mm(model.layers[il].wqkv, input); + qkv_mixed = ggml_reshape_3d(ctx0, qkv_mixed, qkv_mixed->ne[0], n_seq_tokens, n_seqs); + cb(qkv_mixed, "linear_attn_qkv_mixed", il); + + ggml_tensor * z = build_lora_mm(model.layers[il].wqkv_gate, input); + cb(z, "z", il); + + return { qkv_mixed, z }; +} + +ggml_tensor * llm_build_qwen35moe::build_norm_gated( + ggml_tensor * input, + ggml_tensor * weights, + ggml_tensor * gate, + int layer) { + ggml_tensor * normalized = build_norm(input, weights, nullptr, LLM_NORM_RMS, layer); + ggml_tensor * gated_silu = ggml_silu(ctx0, gate); + + return ggml_mul(ctx0, normalized, gated_silu); +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn( + llm_graph_input_attn_kv * inp, + ggml_tensor * cur, + ggml_tensor * inp_pos, + int * sections, + int il) { + const int64_t n_embd_head = hparams.n_embd_head_v; + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + // Order: joint QG projection, QG split, Q norm, KV projection, K norm, RoPE, attention + + // Qwen3Next uses a single Q projection that outputs query + gate + ggml_tensor * Qcur_full = build_lora_mm(model.layers[il].wq, cur); // [ (n_embd_head * 2) * n_head, n_tokens ] + cb(Qcur_full, "Qcur_full", il); + + ggml_tensor * Qcur = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, 0); + cb(Qcur, "Qcur_reshaped", il); + + // Apply Q normalization + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + // Apply K normalization + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + ggml_tensor * gate = ggml_view_3d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, + ggml_element_size(Qcur_full) * n_embd_head * 2, + ggml_element_size(Qcur_full) * n_embd_head * 2 * n_head, + ggml_element_size(Qcur_full) * n_embd_head); + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + cb(gate, "gate_reshaped", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + // Apply IMRoPE + Qcur = ggml_rope_multi( + ctx0, Qcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_multi( + ctx0, Kcur, inp_pos, nullptr, + n_rot, sections, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + // Attention computation + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + cur = build_attn(inp, + nullptr, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_pregate", il); + + ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); + cb(gate_sigmoid, "gate_sigmoid", il); + + cur = ggml_mul(ctx0, cur, gate_sigmoid); + cb(cur, "attn_gated", il); + + cur = build_lora_mm(model.layers[il].wo, cur); + cb(cur, "attn_output", il); + + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( + llm_graph_input_rs * inp, + ggml_tensor * cur, + int il) { + const auto * mctx_cur = inp->mctx; + + const int64_t d_inner = hparams.ssm_d_inner; + const int64_t n_seqs = ubatch.n_seqs; + const int64_t head_k_dim = hparams.ssm_d_state; + const int64_t num_k_heads = hparams.ssm_n_group; + const int64_t num_v_heads = hparams.ssm_dt_rank; + const int64_t head_v_dim = d_inner / num_v_heads; + const int64_t n_seq_tokens = ubatch.n_seq_tokens; + + const auto kv_head = mctx_cur->get_head(); + + GGML_ASSERT(n_seqs != 0); + GGML_ASSERT(ubatch.equal_seqs()); + GGML_ASSERT(ubatch.n_tokens == n_seq_tokens * n_seqs); + + // Input projections + auto qkvz = build_qkvz(cur, il); + ggml_tensor * qkv_mixed = qkvz.first; + ggml_tensor * z = qkvz.second; + + ggml_tensor * beta = build_lora_mm(model.layers[il].ssm_beta, cur); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + cb(beta, "beta", il); + + beta = ggml_sigmoid(ctx0, beta); + + ggml_tensor * alpha = build_lora_mm(model.layers[il].ssm_alpha, cur); + alpha = ggml_cont_3d(ctx0, alpha, num_v_heads, n_seq_tokens, n_seqs); + cb(alpha, "alpha", il); + + ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); + ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); + cb(alpha_softplus, "a_softplus", il); + + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus + cb(gate, "gate", il); + + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + + // Get convolution states from cache + ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); + ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); + + // Build the convolution states tensor + ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); + cb(conv_states, "conv_states", il); + + // Calculate convolution kernel size + ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; + const int64_t conv_kernel_size = conv_kernel->ne[0]; + const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + cb(conv_states, "conv_states_reshaped", il); + + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); + + ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); + cb(conv_input, "conv_input", il); + + // Update convolution state cache + // Extract the last (conv_kernel_size - 1) states from conv_input + ggml_tensor * last_conv_states = + ggml_view_3d(ctx0, conv_input, conv_kernel_size - 1, conv_channels, n_seqs, conv_input->nb[1], + conv_input->nb[2], (conv_input->ne[0] - conv_states->ne[0]) * ggml_element_size(conv_input)); + cb(last_conv_states, "last_conv_states", il); + + ggml_tensor * state_update_target = + ggml_view_1d(ctx0, conv_states_all, (conv_kernel_size - 1) * conv_channels * n_seqs, + kv_head * (conv_kernel_size - 1) * conv_channels * ggml_element_size(conv_states_all)); + cb(state_update_target, "state_update_target", il); + + ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); + cb(conv_states_all, "conv_states_updated", il); + + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); + cb(conv_output_proper, "conv_output_raw", il); + + ggml_tensor * conv_output_silu = ggml_silu(ctx0, conv_output_proper); + cb(conv_output_silu, "conv_output_silu", il); + + ggml_tensor * conv_qkv_mix = conv_output_silu; + + // Calculate the total conv dimension + int64_t qkv_dim = head_k_dim * num_k_heads * 2 + head_v_dim * num_v_heads; + int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); + + // Extract the convolved Q, K, V from conv_output + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + + cb(q_conv, "q_conv", il); + cb(k_conv, "k_conv", il); + cb(v_conv, "v_conv", il); + + const float eps_norm = hparams.f_norm_rms_eps; + + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // if head keys and value keys are different, repeat to force tensors into matching shapes + if (num_k_heads != num_v_heads) { + GGML_ASSERT(num_v_heads % num_k_heads == 0); + // TODO: try to avoid these explicit repeats by utilizing op broadcast + q_conv = ggml_repeat_4d(ctx0, q_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + k_conv = ggml_repeat_4d(ctx0, k_conv, head_k_dim, num_v_heads, n_seq_tokens, n_seqs); + } + + cb(q_conv, "q_conv_predelta", il); + cb(k_conv, "k_conv_predelta", il); + cb(v_conv, "v_conv_predelta", il); + + // Choose between build_delta_net_chunking, build_delta_net_recurrent, and build_delta_net_autoregressive based on n_tokens + std::pair attn_out; // pair of (output, new_state) + if (n_seq_tokens == 1) { + attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); + } else { + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); + } + ggml_tensor * output = attn_out.first; + ggml_tensor * new_state = attn_out.second; + cb(output, "attn_output", il); + cb(new_state, "new_state", il); + + // Update the recurrent states + ggml_build_forward_expand(gf, + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); + + // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + + // Apply gated normalization: self.norm(core_attn_out, z) + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); + + // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] + ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); + cb(final_output, "final_output", il); + + // Output projection + cur = build_lora_mm(model.layers[il].ssm_out, final_output); + cb(cur, "linear_attn_out", il); + + // Reshape back to original dimensions + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + + return cur; +} + +ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int il) { + // Check if this is an MoE layer + GGML_ASSERT(model.layers[il].ffn_gate_inp != nullptr); + + ggml_tensor * moe_out = + build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, + nullptr, + n_expert, n_expert_used, LLM_FFN_SILU, + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + cb(moe_out, "ffn_moe_out", il); + + // Add shared experts if present - following Qwen3Next reference implementation + if (model.layers[il].ffn_up_shexp != nullptr) { + ggml_tensor * ffn_shexp = + build_ffn(cur, + model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_gate_shexp, NULL, NULL, + model.layers[il].ffn_down_shexp, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(ffn_shexp, "ffn_shexp", il); + + // Apply shared expert gating as in the reference implementation + // The shared expert has its own gate that is sigmoided + // Note: ffn_gate_inp_shexp is the shared expert gate (outputs 1 value per token) + ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); + cb(shared_gate, "shared_expert_gate", il); + + // Apply sigmoid to the gate + shared_gate = ggml_sigmoid(ctx0, shared_gate); + cb(shared_gate, "shared_expert_gate_sigmoid", il); + + + // Apply the gate to the shared expert output + ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); + cb(ffn_shexp, "ffn_shexp_gated", il); + + cur = ggml_add(ctx0, moe_out, ffn_shexp); + cb(cur, "ffn_out", il); + } else { + cur = moe_out; + } + + return cur; +} diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 99b1a76a48..974120ea6f 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -16,17 +15,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -36,7 +24,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -94,354 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -472,39 +112,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( // Split Q projection into query and gate // The split should be along dimension 0 (the feature dimension) ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, - Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + cb(Qcur, "Qcur_view", il); + ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); - cb(Qcur, "Qcur", il); cb(gate, "gate", il); - // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur_reshaped", il); - - // Apply Q normalization - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "gate_reshaped", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - // Apply RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -519,7 +149,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // Attention computation const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, @@ -527,10 +156,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); - ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); - cb(gate_sigmoid, "gate_sigmoid", il); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cur = ggml_mul(ctx0, cur, gate_sigmoid); + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "gate_sigmoid", il); + + gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + + cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); cur = build_lora_mm(model.layers[il].wo, cur); @@ -560,7 +194,6 @@ std::pair llm_build_qwen3next::build_qkvz( cb(z, "z", il); return { qkv_mixed, z }; - } else { // legacy (slower) path ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); @@ -624,9 +257,6 @@ std::pair llm_build_qwen3next::build_qkvz( ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -671,7 +301,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + b = ggml_cont(ctx0, b); + + ggml_tensor * beta = ggml_sigmoid(ctx0, b); // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); @@ -679,15 +312,17 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); + beta = ggml_reshape_4d(ctx0, beta, 1, num_v_heads, n_seq_tokens, n_seqs); + gate = ggml_reshape_4d(ctx0, gate, 1, num_v_heads, n_seq_tokens, n_seqs); + // Get convolution states from cache ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -696,11 +331,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -720,7 +356,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -734,26 +373,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { @@ -786,7 +435,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -795,19 +444,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, n_seq_tokens, n_seqs); @@ -818,7 +463,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } @@ -839,7 +485,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, @@ -852,11 +498,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); cb(shared_gate, "shared_expert_gate", il); - // Apply sigmoid to the gate shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); diff --git a/src/models/rwkv6-base.cpp b/src/models/rwkv6-base.cpp index 7beed2daff..83aeab7280 100644 --- a/src/models/rwkv6-base.cpp +++ b/src/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..7fcab77745 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/unicode.cpp b/src/unicode.cpp index adfc489d1f..1475b53b65 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -1,16 +1,10 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "unicode.h" #include "unicode-data.h" #include #include -#include #include #include -#include #include #include #include @@ -199,27 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } -static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert> conv; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return conv.from_bytes(s); -} - static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { std::vector bpe_encoded_words; for (const auto & word : bpe_words) { @@ -796,6 +769,12 @@ static std::vector unicode_regex_split_custom(const std::string & text, } else if (regex_expr == "\\p{AFMoE_digits}") { // AFMOE digit pattern - use custom implementation for proper splitting bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); + } else if (regex_expr == "\\d{1,3}(?=(?:\\d{3})*\\b)") { + // tiny_aya digit grouping pattern from tokenizer.json: + // {"type": "Split", "pattern": {"Regex": "\\d{1,3}(?=(?:\\d{3})*\\b)"}, "behavior": "Isolated"} + // Splits digits into groups of 3 from the right (e.g., 1234567 -> 1, 234, 567) + // TODO: Revisit this regex, incase there are any subtle tokenization differences with the original regex. + bpe_offsets = unicode_regex_split_custom_afmoe(text, offsets); } return bpe_offsets; @@ -1028,10 +1007,10 @@ std::vector unicode_regex_split(const std::string & text, const std break; } } + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); if (use_collapsed) { // sanity-check that the original regex does not contain any non-ASCII characters - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); for (size_t i = 0; i < cpts_regex.size(); ++i) { if (cpts_regex[i] >= 128) { throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); @@ -1087,7 +1066,7 @@ std::vector unicode_regex_split(const std::string & text, const std bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); } else { // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end()); // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c9436c5995..350bffc315 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -11,7 +11,9 @@ function(llama_build source) add_executable(${TEST_TARGET} ${TEST_SOURCES}) target_link_libraries(${TEST_TARGET} PRIVATE common) - install(TARGETS ${TEST_TARGET} RUNTIME) + if (LLAMA_TESTS_INSTALL) + install(TARGETS ${TEST_TARGET} RUNTIME) + endif() endfunction() function(llama_test target) @@ -100,7 +102,9 @@ function(llama_build_and_test source) endif() add_executable(${TEST_TARGET} ${TEST_SOURCES}) - install(TARGETS ${TEST_TARGET} RUNTIME) + if (LLAMA_TESTS_INSTALL) + install(TARGETS ${TEST_TARGET} RUNTIME) + endif() target_link_libraries(${TEST_TARGET} PRIVATE common) add_test( diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6fe1780f3b..746648a064 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1943,7 +1943,11 @@ struct test_unary : public test_case { ggml_tensor * a; if (v & 1) { - auto ne = ne_a; ne[0] *= 3; + auto ne = ne_a; + ne[0] *= 3; + ne[1] *= 2; + ne[2] *= 5; + ne[3] *= 4; a = ggml_new_tensor(ctx, type, 4, ne.data()); if (grad_supported) { ggml_set_param(a); @@ -2782,9 +2786,10 @@ struct test_set : public test_case { const ggml_type type_dst; const std::array ne; const int dim; + const bool inplace; std::string vars() override { - return VARS_TO_STR4(type_src, type_dst, ne, dim); + return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace); } size_t op_size(ggml_tensor * t) override { @@ -2792,8 +2797,8 @@ struct test_set : public test_case { } test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, - std::array ne = {6, 5, 4, 3}, int dim = 1) - : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {} + std::array ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false) + : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2804,7 +2809,7 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { ne_dst[i] *= 2; } - ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); + ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); ggml_set_param(dst); ggml_set_name(dst, "dst"); @@ -2812,9 +2817,16 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i]; } - ggml_tensor * out = ggml_set(ctx, dst, src, - // The backward pass requires setting a contiguous region: - src->nb[1], src->nb[2], src->nb[3], offset); + ggml_tensor * out; + if (inplace) { + out = ggml_set_inplace(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } else { + out = ggml_set(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } ggml_set_name(out, "out"); return out; @@ -2964,11 +2976,12 @@ struct test_bin_bcast : public test_case { const std::array ne; const std::array nr; int nf; // number of fused ops, nf == 1 -> single op (no fusion) + bool perm1; // permute src1? bool run_whole_graph() override { return nf > 1; } std::string vars() override { - return VARS_TO_STR4(type, ne, nr, nf); + return VARS_TO_STR5(type, ne, nr, nf, perm1); } size_t op_size(ggml_tensor * t) override { @@ -2978,8 +2991,9 @@ struct test_bin_bcast : public test_case { test_bin_bcast(op_t op, ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 1, 1}, std::array nr = {1, 2, 1, 1}, - int nf = 1) - : op(op), type(type), ne(ne), nr(nr), nf(nf) {} + int nf = 1, + bool perm1 = false) + : op(op), type(type), ne(ne), nr(nr), nf(nf), perm1(perm1) {} ggml_tensor * build_graph(ggml_context * ctx) override { GGML_ASSERT(nf <= 16); @@ -2989,12 +3003,19 @@ struct test_bin_bcast : public test_case { ggml_tensor * b[16]; for (int i = 0; i < nf; ++i) { - b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + if (perm1) { + const int p[4] = { 1, 2, 0, 3 }; // hardcoded for now + + b[i] = ggml_new_tensor_4d(ctx, type, ne[p[0]], ne[p[1]], ne[p[2]], ne[p[3]]); + b[i] = ggml_permute(ctx, b[i], p[0], p[1], p[2], p[3]); + } else { + b[i] = ggml_new_tensor(ctx, type, 4, ne.data()); + } ggml_set_name(b[i], (std::string("b") + std::to_string(i)).c_str()); } // The backward pass supports broadcasting only for GGML_ADD: - const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1; + const bool grad_supported = op == ggml_add && ggml_are_same_shape(a, b[0]) && nf == 1 && !perm1; if (grad_supported) { ggml_set_param(a); ggml_set_param(b[0]); @@ -5800,20 +5821,27 @@ struct test_l2_norm : public test_case { const ggml_type type; const std::array ne; const float eps; + bool v; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR4(type, ne, eps, v); } test_l2_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 64, 320, 1}, - float eps = 1e-12f) - : type(type), ne(ne), eps(eps) {} + float eps = 1e-12f, + bool v = false) + : type(type), ne(ne), eps(eps), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(a, "a"); + if (v) { + a = ggml_view_4d(ctx, a, a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view of a"); + } + ggml_tensor * out = ggml_l2_norm(ctx, a, eps); ggml_set_name(out, "out"); @@ -5826,26 +5854,46 @@ struct test_acc : public test_case { const ggml_type type; const std::array ne_a; const std::array ne_b; + const int64_t stride_dim; std::string vars() override { - return VARS_TO_STR3(type, ne_a, ne_b); + return VARS_TO_STR4(type, ne_a, ne_b, stride_dim); } test_acc(ggml_type type = GGML_TYPE_F32, - std::array ne_a = {256, 17, 1, 1}, - std::array ne_b = {256, 16, 1, 1}) - : type(type), ne_a(ne_a), ne_b(ne_b) {} + std::array ne_a = {256, 17, 2, 3}, + std::array ne_b = {256, 16, 2, 3}, + uint64_t stride_dim = -1) + : type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); - ggml_set_param(b); + ggml_tensor * b; + if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) { + // Create a larger tensor and take a view at a non-zero offset. + // This tests that the backend correctly handles b's data offset + std::array ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]}; + ne_b_pad[stride_dim] += 1; + ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data()); + ggml_set_param(b_pad); + ggml_set_name(b_pad, "b_pad"); + // View that skips the first row, so b has a non-zero byte offset + b = ggml_view_4d(ctx, b_pad, + ne_b[0], ne_b[1], ne_b[2], ne_b[3], + b_pad->nb[1], b_pad->nb[2], b_pad->nb[3], + b_pad->nb[1]); + } else { + b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_set_param(b); + } ggml_set_name(b, "b"); - ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]); + // When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride + // parameters to ggml_acc don't match b's natural stride. + ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(out, "out"); return out; @@ -5894,33 +5942,36 @@ struct test_pad_ext : public test_case { const int rp2; const int lp3; const int rp3; - const bool v; + const int tfrm; // 0 - none, 1 - non-cont, 2 - perm const bool circular; std::string vars() override { - return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, v, circular); + return VARS_TO_STR12(type, ne_a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, tfrm, circular); } test_pad_ext(ggml_type type = GGML_TYPE_F32, std::array ne_a = {512, 512, 3, 1}, int lp0 = 1, int rp0 = 1, int lp1 = 1, int rp1 = 1, int lp2 = 1, int rp2 = 1, int lp3 = 1, int rp3 = 1, - bool v = false, bool circular = false) + int tfrm = 0, bool circular = false) : type(type), ne_a(ne_a), lp0(lp0), rp0(rp0), lp1(lp1), rp1(rp1), lp2(lp2), rp2(rp2), lp3(lp3), rp3(rp3), - v(v), circular(circular) {} + tfrm(tfrm), circular(circular) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_name(a, "a"); - if (v) { + if (tfrm == 1) { a = ggml_view_4d(ctx, a, (a->ne[0] + 1) / 2, (a->ne[1] + 1) / 2, (a->ne[2] + 1) / 2, (a->ne[3] + 1) / 2, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(a, "view of a"); + } else if (tfrm == 2) { + a = ggml_permute(ctx, a, 2, 1, 0, 3); + ggml_set_name(a, "permuted a"); } ggml_tensor * out = circular ? ggml_pad_ext_circular(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3) - : ggml_pad_ext(ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); + : ggml_pad_ext (ctx, a, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3); ggml_set_name(out, "out"); return out; @@ -7412,11 +7463,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true)); } for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true)); } // same-type copy @@ -7474,25 +7527,27 @@ static std::vector> make_test_cases_eval() { } } - auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr) { + auto add_test_bin_bcast = [&](ggml_type type, std::array ne, std::array nr, bool perm1 = false) { for (auto op : {ggml_add, ggml_sub, ggml_mul, ggml_div}) { - test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr)); + test_cases.emplace_back(new test_bin_bcast(op, type, ne, nr, 1, perm1)); } }; for (ggml_type type : {GGML_TYPE_F16, GGML_TYPE_F32}) { - add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}); - add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}); - add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}); + for (bool perm1 : {false, true}) { + add_test_bin_bcast(type, {1, 1, 8, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 1, 1}, {32, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {1, 1, 320, 320}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 1, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 1}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 1, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 1, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 1}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 1, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 1, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {1, 2, 2, 2}, perm1); + add_test_bin_bcast(type, {10, 5, 4, 3}, {2, 2, 2, 2}, perm1); + } // test case for k_bin_bcast_unravel in CUDA backend add_test_bin_bcast(type, {1, 1, 65536, 1}, {256, 1, 1, 1}); @@ -7548,7 +7603,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); } test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); - test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true)); } } @@ -7879,20 +7935,27 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_round (type)); test_cases.emplace_back(new test_trunc (type)); test_cases.emplace_back(new test_sqr (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sqr (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_sqrt (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sqrt (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_log (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_log (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_sin (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_sin (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_cos (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_cos (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_clamp (type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_clamp (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_leaky_relu(type, {7, 1, 5, 3})); + test_cases.emplace_back(new test_leaky_relu(type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_floor (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_floor (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_floor (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_ceil (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_ceil (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_ceil (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_round (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_round (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_round (type, {1024, 1024, 1, 1})); test_cases.emplace_back(new test_trunc (type, {7, 1, 5, 3})); - test_cases.emplace_back(new test_trunc (type, { 1024, 1024, 1, 1 })); + test_cases.emplace_back(new test_trunc (type, {1024, 1024, 1, 1})); } test_cases.emplace_back(new test_diag_mask_inf(GGML_TYPE_F32, {10, 10, 1, 1}, 5)); @@ -8107,29 +8170,40 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_sum()); - test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mean()); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous + test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); - test_cases.emplace_back(new test_mean()); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1})); - test_cases.emplace_back(new test_acc()); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3)); test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular test_cases.emplace_back(new test_pad_ext()); @@ -8198,10 +8272,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 200, 64, 4, 4 })); test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 4 }, { 384, 64, 4, 4 })); - for (bool v : {false, true}) { + for (int tfrm : {0, 1, 2}) { for (bool circular : {false, true}) { - test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, v, circular)); - test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, v, circular)); + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {512, 512, 1, 1}, 0, 1, 0, 1, 0, 0, 0, 0, tfrm, circular)); + test_cases.emplace_back(new test_pad_ext(GGML_TYPE_F32, {11, 22, 33, 44}, 1, 2, 3, 4, 5, 6, 7, 8, tfrm, circular)); } } @@ -8227,7 +8301,7 @@ static std::vector> make_test_cases_eval() { //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) { for (int kv : { 113, 512, 1024, }) { if (nr2 != 1 && kv != 512) continue; - for (int nb : { 1, 3, 32, 35, }) { + for (int nb : { 1, 3, 32, 75, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { @@ -8520,7 +8594,7 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_rope(type, { 80, 32, 512, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (stablelm) test_cases.emplace_back(new test_rope(type, { 64, 8, 512, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // neox (falcon 40B) test_cases.emplace_back(new test_rope(type, {128, 12, 512, 1}, 128, GGML_ROPE_TYPE_MROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl 2B) - test_cases.emplace_back(new test_rope(type, {128, 12, 2, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B) + test_cases.emplace_back(new test_rope(type, {128, 12, 512, 1}, 128, GGML_ROPE_TYPE_IMROPE, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,imrope (qwen3vl 2B) test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, 1.0f, 0.0f, 1.0f, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } } @@ -8564,6 +8638,14 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate + // acc + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3)); + return test_cases; } diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index 1f25c6ae71..f5197bd33f 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -691,6 +691,48 @@ static void test_filters(testing & t) { "{\n \"a\": 1,\n \"b\": [\n 1,\n 2\n ]\n}" ); + test_template(t, "indent", + "{{ data|indent(2) }}", + {{ "data", "foo\nbar" }}, + "foo\n bar" + ); + + test_template(t, "indent first only", + "{{ data|indent(width=3,first=true) }}", + {{ "data", "foo\nbar" }}, + " foo\n bar" + ); + + test_template(t, "indent blank lines and first line", + "{{ data|indent(width=5,blank=true,first=true) }}", + {{ "data", "foo\n\nbar" }}, + " foo\n \n bar" + ); + + test_template(t, "indent with default width", + "{{ data|indent() }}", + {{ "data", "foo\nbar" }}, + "foo\n bar" + ); + + test_template(t, "indent with no newline", + "{{ data|indent }}", + {{ "data", "foo" }}, + "foo" + ); + + test_template(t, "indent with trailing newline", + "{{ data|indent(blank=true) }}", + {{ "data", "foo\n" }}, + "foo\n " + ); + + test_template(t, "indent with string", + "{{ data|indent(width='>>>>') }}", + {{ "data", "foo\nbar" }}, + "foo\n>>>>bar" + ); + test_template(t, "chained filters", "{{ ' HELLO '|trim|lower }}", json::object(), diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 02ccb72598..ad421e6326 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -52,6 +52,7 @@ struct cli_context { json messages = json::array(); std::vector input_files; task_params defaults; + bool verbose_prompt; // thread for showing "loading" animation std::atomic loading_show; @@ -66,6 +67,8 @@ struct cli_context { defaults.stream = true; // make sure we always use streaming mode defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way // defaults.return_progress = true; // TODO: show progress + + verbose_prompt = params.verbose_prompt; } std::string generate_completion(result_timings & out_timings) { @@ -91,6 +94,12 @@ struct cli_context { rd.post_task({std::move(task)}); } + if (verbose_prompt) { + console::set_display(DISPLAY_TYPE_PROMPT); + console::log("%s\n\n", chat_params.prompt.c_str()); + console::set_display(DISPLAY_TYPE_RESET); + } + // wait for first result console::spinner::start(); server_task_result_ptr result = rd.next(should_stop); diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 751440af32..755a3d4b00 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -19,6 +19,8 @@ add_library(mtmd models/glm4v.cpp models/internvl.cpp models/kimivl.cpp + models/kimik25.cpp + models/nemotron-v2-vl.cpp models/llama4.cpp models/llava.cpp models/minicpmv.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index ad232178bf..03bedf9d3f 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -235,6 +235,8 @@ enum projector_type { PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, + PROJECTOR_TYPE_KIMIK25, + PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_UNKNOWN, }; @@ -268,6 +270,8 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, + { PROJECTOR_TYPE_KIMIK25, "kimik25"}, + { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index d4ff9151bb..e0eb9b32c8 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -15,6 +15,7 @@ enum ffn_op_type { FFN_GELU_ERF, FFN_SILU, FFN_GELU_QUICK, + FFN_RELU_SQR, }; enum norm_type { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index 9fa5afc390..57f6dd00a3 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -10,6 +10,7 @@ #include "ggml-backend.h" #include "gguf.h" +#include #include #include #include @@ -341,9 +342,17 @@ ggml_tensor * clip_graph::build_vit( /* nb2 */ cur->nb[1], /* offset */ ggml_row_size(cur->type, 2 * n_embd)); - // TODO: q/k norm requires row size == n_embd, while here it's d_head - // we can add support in the future if needed - GGML_ASSERT(layer.q_norm == nullptr && layer.k_norm == nullptr); + if (layer.q_norm) { + GGML_ASSERT(layer.q_norm->ne[0] == Qcur->ne[0]); + Qcur = build_norm(Qcur, layer.q_norm, NULL, norm_t, eps, il); + cb(Qcur, "Qcur_norm", il); + } + + if (layer.k_norm) { + GGML_ASSERT(layer.k_norm->ne[0] == Kcur->ne[0]); + Kcur = build_norm(Kcur, layer.k_norm, NULL, norm_t, eps, il); + cb(Kcur, "Kcur_norm", il); + } } else { // separate q, k, v @@ -558,6 +567,12 @@ ggml_tensor * clip_graph::build_ffn( cur = ggml_gelu_quick(ctx0, cur); cb(cur, "ffn_gelu_quick", il); } break; + case FFN_RELU_SQR: + { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_relu_sqr", il); + } break; } if (down) { @@ -613,9 +628,6 @@ ggml_tensor * clip_graph::build_attn( ggml_tensor * v = ggml_permute(ctx0, v_cur, 1, 2, 0, 3); v = ggml_cont(ctx0, v); - const auto n_tokens = q->ne[1]; - const auto n_head = q->ne[2]; - ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); // F32 may not needed for vision encoders? // ggml_mul_mat_set_prec(kq, GGML_PREC_F32); @@ -624,7 +636,7 @@ ggml_tensor * clip_graph::build_attn( ggml_tensor * kqv = ggml_mul_mat(ctx0, v, kq); cur = ggml_permute(ctx0, kqv, 0, 2, 1, 3); - cur = ggml_cont_2d(ctx0, cur, cur->ne[0]*n_head, n_tokens); + cur = ggml_cont_2d(ctx0, cur, cur->ne[0] * cur->ne[1], cur->ne[2] * cur->ne[3]); } cb(cur, "kqv_out", il); @@ -672,8 +684,8 @@ ggml_tensor * clip_graph::build_rope_2d( { first = ggml_view_3d(ctx0, cur, n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), + cur->nb[1], + cur->nb[2], 0); first = ggml_rope_ext( ctx0, @@ -691,8 +703,8 @@ ggml_tensor * clip_graph::build_rope_2d( { second = ggml_view_3d(ctx0, cur, n_dim/2, n_head, n_pos, - ggml_row_size(cur->type, n_dim), - ggml_row_size(cur->type, n_dim*n_head), + cur->nb[1], + cur->nb[2], n_dim/2 * ggml_element_size(cur)); second = ggml_rope_ext( ctx0, @@ -809,6 +821,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_NEMOTRON_V2_VL: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_LLAMA4: { builder = std::make_unique(ctx, img); @@ -825,6 +841,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_KIMIK25: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_COGVLM: { builder = std::make_unique(ctx, img); @@ -1105,6 +1125,7 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; @@ -1116,9 +1137,8 @@ struct clip_model_loader { case PROJECTOR_TYPE_LFM2: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); - // ref: https://huggingface.co/LiquidAI/LFM2-VL-3B/blob/main/preprocessor_config.json - // config above specifies number of tokens after downsampling, while here it is before, relax lowerbound to 64 - hparams.set_limit_image_tokens(64, 1024); + // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json + hparams.set_limit_image_tokens(64, 256); } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_LIGHTONOCR: @@ -1139,6 +1159,22 @@ struct clip_model_loader { hparams.set_limit_image_tokens(8, 1024); hparams.set_warmup_n_tokens(256); // avoid OOM on warmup } break; + case PROJECTOR_TYPE_KIMIK25: + { + hparams.rope_theta = 10000.0f; + get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); + + int min_pixels = 0, max_pixels = 0; + get_u32(KEY_IMAGE_MIN_PIXELS, min_pixels, false); + get_u32(KEY_IMAGE_MAX_PIXELS, max_pixels, false); + if (min_pixels > 0 && max_pixels > 0) { + hparams.image_min_pixels = min_pixels; + hparams.image_max_pixels = max_pixels; + hparams.warmup_image_size = static_cast(std::sqrt(max_pixels)); + } else { + hparams.set_limit_image_tokens(2, 4096); + } + } break; case PROJECTOR_TYPE_GEMMA3: { // default value (used by all model sizes in gemma 3 family) @@ -1668,6 +1704,7 @@ struct clip_model_loader { model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B); @@ -1746,6 +1783,12 @@ struct clip_model_loader { model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_NEMOTRON_V2_VL: + { + model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + } break; case PROJECTOR_TYPE_GLMA: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -2807,6 +2850,119 @@ private: } }; +// ref: https://github.com/huggingface/transformers/blob/v5.1.0/src/transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +// some of the logic is similar to llava_uhd, but with different hyperparameters and some logic is unique (e.g. grid layout) +struct lfm2_vl_image_processor { + // ref: https://huggingface.co/LiquidAI/LFM2.5-VL-1.6B/blob/main/processor_config.json + static constexpr int min_tiles = 2; + static constexpr int max_tiles = 10; + static constexpr float max_pixels_tolerance = 2.0f; + static constexpr int tile_size = 512; + + static llava_uhd::slice_instructions get_slice_instructions(struct clip_ctx * ctx, const clip_image_size & original_size) { + llava_uhd::slice_instructions inst; + const auto & params = ctx->model.hparams; + const int align_size = params.patch_size * params.n_merge; + + inst.interpolation_overview = img_tool::RESIZE_ALGO_BILINEAR; + inst.interpolation_refined = img_tool::RESIZE_ALGO_BILINEAR; + inst.overview_size = img_tool::calc_size_preserved_ratio(original_size, align_size, params.image_min_pixels, params.image_max_pixels); + + // tile if either dimension exceeds tile_size with tolerance + const bool needs_tiling = original_size.width > tile_size * max_pixels_tolerance || original_size.height > tile_size * max_pixels_tolerance; + + if (!needs_tiling) { + inst.refined_size = clip_image_size{0, 0}; + inst.grid_size = clip_image_size{0, 0}; + return inst; + } + + const clip_image_size grid = get_grid_layout(original_size.height, original_size.width); + + inst.grid_size = grid; + inst.refined_size = clip_image_size{tile_size * grid.width, tile_size * grid.height}; + + LOG_DBG("%s: original size: %d x %d, overview size: %d x %d, refined size: %d x %d, grid size: %d x %d\n", + __func__, + original_size.width, original_size.height, + inst.overview_size.width, inst.overview_size.height, + inst.refined_size.width, inst.refined_size.height, + grid.width, grid.height); + + for (int row = 0; row < grid.height; row++) { + for (int col = 0; col < grid.width; col++) { + llava_uhd::slice_coordinates slice; + slice.x = col * tile_size; + slice.y = row * tile_size; + slice.size = clip_image_size{tile_size, tile_size}; + inst.slices.push_back(slice); + LOG_DBG("%s: slice %d: x=%d, y=%d, size=%d x %d\n", + __func__, (int)inst.slices.size() - 1, + slice.x, slice.y, slice.size.width, slice.size.height); + } + } + + return inst; + } + +private: + static clip_image_size find_closest_aspect_ratio( + float aspect_ratio, + const std::vector & target_ratios, + int width, int height) { + float best_ratio_diff = std::numeric_limits::max(); + clip_image_size best_ratio = {1, 1}; + const float area = static_cast(width * height); + + for (const auto & ratio : target_ratios) { + const float target_aspect_ratio = static_cast(ratio.width) / ratio.height; + const float ratio_diff = std::abs(aspect_ratio - target_aspect_ratio); + if (ratio_diff < best_ratio_diff) { + best_ratio_diff = ratio_diff; + best_ratio = ratio; + } else if (ratio_diff == best_ratio_diff) { + const float target_area = static_cast(tile_size * tile_size * ratio.width * ratio.height); + if (area > 0.5f * target_area) { + best_ratio = ratio; + } + } + } + return best_ratio; + } + + static std::vector get_target_ratios() { + std::vector ratios; + for (int n = min_tiles; n <= max_tiles; n++) { + for (int w = 1; w <= n; w++) { + for (int h = 1; h <= n; h++) { + if (w * h >= min_tiles && w * h <= max_tiles) { + bool found = false; + for (const auto & r : ratios) { + if (r.width == w && r.height == h) { + found = true; + break; + } + } + if (!found) { + ratios.push_back({w, h}); + } + } + } + } + } + std::sort(ratios.begin(), ratios.end(), [](const clip_image_size & a, const clip_image_size & b) { + return a.width * a.height < b.width * b.height; + }); + return ratios; + } + + static clip_image_size get_grid_layout(int height, int width) { + const float aspect_ratio = static_cast(width) / height; + const auto ratios = get_target_ratios(); + return find_closest_aspect_ratio(aspect_ratio, ratios, width, height); + } +}; + // returns the normalized float tensor for llava-1.5, for spatial_unpad with anyres processing for llava-1.6 it returns the normalized image patch tensors as a vector // res_imgs memory is being allocated here, previous allocations will be freed if found bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, struct clip_image_f32_batch * res_imgs) { @@ -2954,6 +3110,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_GLM_EDGE: case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution + case PROJECTOR_TYPE_NEMOTRON_V2_VL: { clip_image_u8 resized_image; int sz = params.image_size; @@ -3021,6 +3178,20 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str } break; case PROJECTOR_TYPE_LFM2: + { + auto const inst = lfm2_vl_image_processor::get_slice_instructions(ctx, original_size); + std::vector imgs = llava_uhd::slice_image(img, inst); + + for (size_t i = 0; i < imgs.size(); ++i) { + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = inst.grid_size.width; + res_imgs->grid_y = inst.grid_size.height; + } break; + case PROJECTOR_TYPE_KIMIVL: { GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); @@ -3032,8 +3203,24 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str const std::array pad_color = {122, 116, 104}; clip_image_u8 resized_img; - const bool pad = (ctx->proj_type() != PROJECTOR_TYPE_LFM2); - img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, pad, pad_color); + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BILINEAR, true, pad_color); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } break; + + case PROJECTOR_TYPE_KIMIK25: + { + GGML_ASSERT(params.image_min_pixels > 0 && params.image_max_pixels > 0); + const clip_image_size target_size = img_tool::calc_size_preserved_ratio( + original_size, + params.patch_size * params.n_merge, + params.image_min_pixels, + params.image_max_pixels); + const std::array pad_color = {0, 0, 0}; + + clip_image_u8 resized_img; + img_tool::resize(*img, resized_img, target_size, img_tool::RESIZE_ALGO_BICUBIC, true, pad_color); clip_image_f32_ptr res(clip_image_f32_init()); normalize_image_u8_to_f32(resized_img, *res, params.image_mean, params.image_std); res_imgs->entries.push_back(std::move(res)); @@ -3233,6 +3420,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_LLAMA4: { // both X and Y are downscaled by the scale factor @@ -3247,6 +3435,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im } break; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: { // dynamic size int out_patch_size = params.patch_size * ctx->model.hparams.n_merge; @@ -3588,6 +3777,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima } break; case PROJECTOR_TYPE_PIXTRAL: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: case PROJECTOR_TYPE_LIGHTONOCR: { // set the 2D positions @@ -3639,6 +3829,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_GEMMA3NV: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: @@ -3724,6 +3915,47 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima ggml_backend_tensor_get(embeddings, vec, 0, ggml_nbytes(embeddings)); } + // Debug: dump final embeddings if MTMD_DEBUG_EMBEDDINGS is set + if (std::getenv("MTMD_DEBUG_EMBEDDINGS") != nullptr) { + const int64_t n_embd = embeddings->ne[0]; + const int64_t n_tokens = embeddings->ne[1]; + std::vector emb_data(n_embd * n_tokens); + ggml_backend_tensor_get(embeddings, emb_data.data(), 0, ggml_nbytes(embeddings)); + + LOG_INF("\n=== MTMD_DEBUG_EMBEDDINGS ===\n"); + LOG_INF("Shape: [%lld, %lld]\n", (long long)n_embd, (long long)n_tokens); + + // Print first few values of first token + LOG_INF("Token 0 (first 16 values): "); + for (int i = 0; i < std::min((int64_t)16, n_embd); i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + + // Print last few values of first token + if (n_embd > 16) { + LOG_INF("Token 0 (last 16 values): "); + for (int64_t i = n_embd - 16; i < n_embd; i++) { + LOG_INF("%.6f ", emb_data[i]); + } + LOG_INF("\n"); + } + + // Compute and print statistics + float sum = 0.0f, sum_sq = 0.0f, min_val = emb_data[0], max_val = emb_data[0]; + for (size_t i = 0; i < emb_data.size(); i++) { + sum += emb_data[i]; + sum_sq += emb_data[i] * emb_data[i]; + min_val = std::min(min_val, emb_data[i]); + max_val = std::max(max_val, emb_data[i]); + } + float mean = sum / emb_data.size(); + float variance = (sum_sq / emb_data.size()) - (mean * mean); + LOG_INF("Stats: mean=%.6f, std=%.6f, min=%.6f, max=%.6f, sum=%.6f\n", + mean, sqrtf(variance), min_val, max_val, sum); + LOG_INF("=== END MTMD_DEBUG_EMBEDDINGS ===\n\n"); + } + return true; } @@ -3761,6 +3993,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_MUSIC_FLAMINGO: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: return ctx->model.mm_3_w->ne[1]; case PROJECTOR_TYPE_LLAMA4: return ctx->model.mm_model_proj->ne[1]; @@ -3770,6 +4003,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_LFM2: case PROJECTOR_TYPE_KIMIVL: + case PROJECTOR_TYPE_KIMIK25: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_COGVLM: return ctx->model.mm_4h_to_h_w->ne[1]; diff --git a/tools/mtmd/models/glm4v.cpp b/tools/mtmd/models/glm4v.cpp index f39b6922eb..6f52df41ab 100644 --- a/tools/mtmd/models/glm4v.cpp +++ b/tools/mtmd/models/glm4v.cpp @@ -2,7 +2,6 @@ ggml_cgraph * clip_graph_glm4v::build() { GGML_ASSERT(model.patch_bias != nullptr); - GGML_ASSERT(model.position_embeddings != nullptr); GGML_ASSERT(model.class_embedding == nullptr); const int batch_size = 1; @@ -45,19 +44,22 @@ ggml_cgraph * clip_graph_glm4v::build() { // pos-conv norm inp = build_norm(inp, model.norm_embd_w, model.norm_embd_b, norm_t, eps, -1); - // calculate absolute position embedding and apply - ggml_tensor * learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC); - learned_pos_embd = ggml_cont_4d( - ctx0, learned_pos_embd, - n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); - learned_pos_embd = ggml_reshape_4d( - ctx0, learned_pos_embd, - n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); - learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); - learned_pos_embd = ggml_cont_3d( - ctx0, learned_pos_embd, - n_embd, n_patches_x * n_patches_y, batch_size); - cb(learned_pos_embd, "learned_pos_embd", -1); + ggml_tensor * learned_pos_embd = nullptr; + // Note: GLM-OCR does not have learned position embeddings + if (model.position_embeddings != nullptr) { + learned_pos_embd = resize_position_embeddings(GGML_SCALE_MODE_BICUBIC); + learned_pos_embd = ggml_cont_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, n_patches_y, batch_size); + learned_pos_embd = ggml_reshape_4d( + ctx0, learned_pos_embd, + n_embd * 2, n_patches_x / 2, 2, batch_size * (n_patches_y / 2)); + learned_pos_embd = ggml_permute(ctx0, learned_pos_embd, 0, 2, 1, 3); + learned_pos_embd = ggml_cont_3d( + ctx0, learned_pos_embd, + n_embd, n_patches_x * n_patches_y, batch_size); + cb(learned_pos_embd, "learned_pos_embd", -1); + } auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { return ggml_rope_multi( diff --git a/tools/mtmd/models/kimik25.cpp b/tools/mtmd/models/kimik25.cpp new file mode 100644 index 0000000000..cf9f27f63a --- /dev/null +++ b/tools/mtmd/models/kimik25.cpp @@ -0,0 +1,101 @@ +#include "models.h" +#include +#include + +// note: this is similar to clip_graph::resize_position_embeddings, major difference is having +// the w/h in ne[1] and ne[2] instead of assuming with sqrt. Could try storing the tensor in 2D instead +// with a w*h? Also the permute is a bit different at (2, 1, 0, 3) instead of (2, 0, 1, 3). +ggml_tensor * clip_graph_kimik25::resize_position_embeddings_3d(uint32_t interpolation_mode) { + ggml_tensor * pos_embd = model.position_embeddings; + const int height = img.ny / patch_size; + const int width = img.nx / patch_size; + const uint32_t mode = interpolation_mode; + + GGML_ASSERT(pos_embd); + + const int64_t stored_c = pos_embd->ne[0]; // C = 1152 + const int64_t orig_w = pos_embd->ne[1]; // W = 64 + const int64_t orig_h = pos_embd->ne[2]; // H = 64 + + GGML_ASSERT(stored_c == n_embd); + + if (height == (int)orig_h && width == (int)orig_w) { + // No interpolation needed, just flatten to [C, H*W] + return ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + } + + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); + pos_embd = ggml_interpolate(ctx0, pos_embd, height, width, n_embd, 1, mode); + pos_embd = ggml_permute(ctx0, pos_embd, 2, 1, 0, 3); + pos_embd = ggml_cont_2d(ctx0, pos_embd, n_embd, width * height); + return pos_embd; +} + +ggml_cgraph * clip_graph_kimik25::build() { + ggml_tensor * pos_h = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_h, "pos_h"); + ggml_set_input(pos_h); + + ggml_tensor * pos_w = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_patches); + ggml_set_name(pos_w, "pos_w"); + ggml_set_input(pos_w); + + ggml_tensor * learned_pos_embd = resize_position_embeddings_3d(GGML_SCALE_MODE_BICUBIC); + + // Kimi-K2.5 uses interleaved 2D RoPE pattern natively, but + // Q / K are permuted during conversion to use split format. + auto add_pos = [&](ggml_tensor * cur, const clip_layer &) { + cur = build_rope_2d(ctx0, cur, pos_w, pos_h, hparams.rope_theta, false); + return cur; + }; + + ggml_tensor * inp = build_inp(); + + // I don't know why, but doing this in the build_vit lead to the ggml_add not occurring? + // Doing it manually here does work. + inp = ggml_add(ctx0, inp, learned_pos_embd); + + ggml_tensor * cur = build_vit( + inp, n_patches, + NORM_TYPE_NORMAL, + hparams.ffn_op, + nullptr, + add_pos); + + cb(cur, "vit_out", -1); + + { + // patch_merger + const int scale_factor = model.hparams.n_merge; + cur = build_patch_merge_permute(cur, scale_factor); + + // projection norm + int proj_inp_dim = cur->ne[0]; + int n_merged_patches = cur->ne[1]; + cur = ggml_view_2d(ctx0, cur, + n_embd, n_merged_patches * scale_factor * scale_factor, + ggml_row_size(cur->type, n_embd), 0); + cur = ggml_norm(ctx0, cur, hparams.eps); + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + cur = ggml_view_2d(ctx0, cur, + proj_inp_dim, n_merged_patches, + ggml_row_size(cur->type, proj_inp_dim), 0); + cb(cur, "proj_inp_normed", -1); + + // projection mlp + cur = build_ffn(cur, + model.mm_1_w, model.mm_1_b, + nullptr, nullptr, + model.mm_2_w, model.mm_2_b, + FFN_GELU, + -1); + + cb(cur, "proj_out", -1); + } + + // build the graph + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 9970980c7b..0beff16c5e 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -42,6 +42,11 @@ struct clip_graph_internvl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_nemotron_v2_vl : clip_graph { + clip_graph_nemotron_v2_vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_llama4 : clip_graph { clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; @@ -109,3 +114,10 @@ struct clip_graph_mobilenetv5 : clip_graph { ggml_tensor * inp, const mobilenetv5_block & block); }; + +struct clip_graph_kimik25 : clip_graph { + clip_graph_kimik25(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; + + ggml_tensor * resize_position_embeddings_3d(uint32_t interpolation_mode); +}; diff --git a/tools/mtmd/models/nemotron-v2-vl.cpp b/tools/mtmd/models/nemotron-v2-vl.cpp new file mode 100644 index 0000000000..03094be1b2 --- /dev/null +++ b/tools/mtmd/models/nemotron-v2-vl.cpp @@ -0,0 +1,35 @@ +#include "models.h" + +ggml_cgraph * clip_graph_nemotron_v2_vl::build() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_registers = model.class_embedding->ne[1]; + const int n_pos = n_patches + n_registers; + + ggml_tensor * inp = build_inp(); + + // add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input) + inp = ggml_add(ctx0, inp, model.position_embeddings); + cb(inp, "inp_pos", -1); + + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + + ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr); + + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), + n_registers * ggml_row_size(cur->type, n_embd)); + + cur = build_patch_merge_permute(cur, model.hparams.n_merge); + + { + cur = build_norm(cur, model.mm_0_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cur = build_ffn(cur, model.mm_1_w, nullptr, nullptr, nullptr, model.mm_3_w, nullptr, FFN_RELU_SQR, -1); + } + + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/mtmd/models/qwen3vl.cpp b/tools/mtmd/models/qwen3vl.cpp index 35a42cb84d..5ecb10fe43 100644 --- a/tools/mtmd/models/qwen3vl.cpp +++ b/tools/mtmd/models/qwen3vl.cpp @@ -182,7 +182,9 @@ ggml_cgraph * clip_graph_qwen3vl::build() { model.mm_1_w, model.mm_1_b, ffn_op_type::FFN_GELU, -1); - embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); // concat along the feature dimension + if (deepstack_features) { + embeddings = ggml_concat(ctx0, embeddings, deepstack_features, 0); + } // concat along the feature dimension // build the graph ggml_build_forward_expand(gf, embeddings); diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index d037e834f3..af733d97d5 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -85,6 +85,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, MTMD_SLICE_TMPL_IDEFICS3, + MTMD_SLICE_TMPL_LFM2, }; const char * mtmd_default_marker() { @@ -174,7 +175,7 @@ struct mtmd_context { clip_context_params ctx_clip_params { /* use_gpu */ ctx_params.use_gpu, - /* flash_attn_type */ CLIP_FLASH_ATTN_TYPE_AUTO, + /* flash_attn_type */ mtmd_get_clip_flash_attn_type(ctx_params.flash_attn_type), /* image_min_tokens */ ctx_params.image_min_tokens, /* image_max_tokens */ ctx_params.image_max_tokens, /* warmup */ ctx_params.warmup, @@ -307,9 +308,19 @@ struct mtmd_context { img_end = "<|im_end|>"; } else if (proj == PROJECTOR_TYPE_LFM2) { - img_beg = "<|image_start|>"; - img_end = "<|image_end|>"; - + // multi-tile: + // <|image_start|> + // <|img_row_1_col_1|> (tile) <|img_row_1_col_2|> (tile) ... + // <|img_thumbnail|> (thumbnail) + // <|image_end|> + // single-tile: + // <|image_start|> (image) <|image_end|> + img_beg = "<|image_start|>"; + img_end = "<|image_end|>"; + slice_tmpl = MTMD_SLICE_TMPL_LFM2; + sli_img_start_tmpl = "<|img_row_%d_col_%d|>"; + tok_ov_img_start = {lookup_token("<|img_thumbnail|>")}; + ov_img_first = false; } else if (proj == PROJECTOR_TYPE_GLM4V) { img_beg = "<|begin_of_image|>"; img_end = "<|end_of_image|>"; @@ -562,11 +573,13 @@ struct mtmd_tokenizer { } // handle llava-uhd style preprocessing + const bool has_tiling_grid = batch_f32.grid_x > 0 && batch_f32.grid_y > 0; if ( ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 + || (ctx->slice_tmpl == MTMD_SLICE_TMPL_LFM2 && has_tiling_grid) ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; diff --git a/tools/mtmd/tests.sh b/tools/mtmd/tests.sh index 012958e0e0..d2b7e684af 100755 --- a/tools/mtmd/tests.sh +++ b/tools/mtmd/tests.sh @@ -28,6 +28,14 @@ if [ "${1:-}" = "huge" ]; then echo "Include BIG and HUGE models..." fi +# Check if the second argument is "flash", then enable flash attention +# This is useful to test if flash attention off works correctly +FLASH_ATTN="on" +if [ "${2:-}" = "flash_off" ] || [ "${1:-}" = "flash_off" ]; then + FLASH_ATTN="off" + echo "Flash attention disabled..." +fi + ############### arr_prefix=() @@ -143,6 +151,7 @@ for i in "${!arr_hf[@]}"; do -hf $(printf %q "$hf") \ --image $(printf %q "$SCRIPT_DIR/$inp_file") \ --temp 0 -n 128 \ + --flash-attn $(printf %q "$FLASH_ATTN") \ ${extra_args}" # if extra_args does not contain -p, we add a default prompt diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index 1ead9c871e..433b747f0d 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -347,7 +347,8 @@ static results_perplexity perplexity_v2(llama_context * ctx, const common_params int count = 0; double nll = 0.0; - LOG_INF("%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + const int n_seq = std::max(1, n_batch / n_ctx); + LOG_INF("%s: computing over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); for (int i = 0; i < n_chunk; ++i) { const int start = i * params.ppl_stride; @@ -1737,11 +1738,21 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { } const int n_batch = params.n_batch; - const int num_batches = (n_ctx + n_batch - 1)/n_batch; + const int num_batches = (static_cast(n_ctx) + n_batch - 1) / n_batch; + // Calculate n_seq based on the logits file's n_ctx, but cap it at what the context supports + const int n_seq_max = llama_n_seq_max(ctx); + int n_seq = std::max(1, n_batch / static_cast(n_ctx)); + if (n_seq > n_seq_max) { + LOG_WRN("%s: calculated n_seq=%d exceeds context's n_seq_max=%d, capping at %d\n", + __func__, n_seq, n_seq_max, n_seq_max); + n_seq = n_seq_max; + } const int nv = 2*((n_vocab + 1)/2) + 4; const bool add_bos = llama_vocab_get_add_bos(vocab); GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); + llama_batch batch = llama_batch_init(std::min(n_batch, static_cast(n_ctx)*n_seq), 0, 1); + std::vector log_probs_uint16(size_t(n_ctx - 1 - n_ctx/2) * nv); std::vector kld_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); std::vector p_diff_values(size_t(n_ctx - 1 - n_ctx/2)*n_chunk); @@ -1750,6 +1761,8 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { logits.reserve(size_t(n_ctx) * n_vocab); } + LOG_INF("%s: computing over %d chunks, n_ctx=%u, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); + std::vector workers(std::thread::hardware_concurrency() - 1); auto mean_and_uncertainty = [] (double sum, double sum2, size_t count) { @@ -1774,107 +1787,122 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { auto kld_ptr = kld_values.data(); auto p_diff_ptr = p_diff_values.data(); - for (int i = 0; i < n_chunk; ++i) { + const int first = n_ctx/2; + + for (int i = 0; i < n_chunk; i += n_seq) { const int start = i * n_ctx; const int end = start + n_ctx; - const auto t_start = std::chrono::high_resolution_clock::now(); + const int n_seq_batch = std::min(n_seq, n_chunk - i); - if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { - LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i); - return; - } + const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache llama_memory_clear(llama_get_memory(ctx), true); - llama_batch batch = llama_batch_init(n_batch, 0, 1); - for (int j = 0; j < num_batches; ++j) { const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - // save original token and restore it after eval - const auto token_org = tokens[batch_start]; - - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[batch_start] = llama_vocab_bos(vocab); - } + int n_outputs = 0; common_batch_clear(batch); - for (int i = 0; i < batch_size; i++) { - common_batch_add(batch, tokens[batch_start + i], j*n_batch + i, {0}, true); + for (int seq = 0; seq < n_seq_batch; seq++) { + int seq_start = batch_start + seq*n_ctx; + + // save original token and restore it after eval + const auto token_org = tokens[seq_start]; + + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[seq_start] = llama_vocab_bos(vocab); + } + + for (int k = 0; k < batch_size; ++k) { + const int pos = j*n_batch + k; + const bool need_logits = pos >= first; + common_batch_add(batch, tokens[seq_start + k], pos, { seq }, need_logits); + n_outputs += need_logits; + } + + // restore the original token in case it was set to BOS + tokens[seq_start] = token_org; } if (llama_decode(ctx, batch)) { - LOG_ERR("%s : failed to eval\n", __func__); + LOG_ERR("%s : failed to decode\n", __func__); llama_batch_free(batch); return; } - // restore the original token in case it was set to BOS - tokens[batch_start] = token_org; - - if (num_batches > 1) { + if (num_batches > 1 && n_outputs > 0) { const auto * batch_logits = llama_get_logits(ctx); - logits.insert(logits.end(), batch_logits, batch_logits + size_t(batch_size) * n_vocab); + logits.insert(logits.end(), batch_logits, batch_logits + size_t(n_outputs) * n_vocab); } } - llama_batch_free(batch); - - const auto t_end = std::chrono::high_resolution_clock::now(); - if (i == 0) { + llama_synchronize(ctx); + const auto t_end = std::chrono::high_resolution_clock::now(); const float t_total = std::chrono::duration(t_end - t_start).count(); LOG_INF("%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total * n_chunk); + int total_seconds = (int)(t_total * n_chunk / n_seq); if (total_seconds >= 60*60) { LOG("%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); } LOG("%.2f minutes\n", total_seconds / 60.0); + LOG("\n"); + LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); } - LOG("\n"); - LOG("chunk PPL ln(PPL(Q)/PPL(base)) KL Divergence Δp RMS Same top p\n"); - const int first = n_ctx/2; - const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); - process_logits(n_vocab, all_logits + size_t(first)*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); - p_diff_ptr += n_ctx - 1 - first; - kld_ptr += n_ctx - 1 - first; + // Read log probs for each sequence in the batch + for (int seq = 0; seq < n_seq_batch; seq++) { + if (in.read((char *)log_probs_uint16.data(), log_probs_uint16.size()*sizeof(uint16_t)).fail()) { + LOG_ERR("%s: failed reading log-probs for chunk %d\n", __func__, i + seq); + llama_batch_free(batch); + return; + } - LOG("%4d", i+1); + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx + first); - auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); - const double ppl_val = exp(log_ppl.first); - const double ppl_unc = ppl_val * log_ppl.second; // ppl_unc = sqrt( (dexp(x) / dx) ** 2 * log_ppl.second ** 2 ) - LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); + process_logits(n_vocab, all_logits, tokens.data() + start + seq*n_ctx + first, n_ctx - 1 - first, + workers, log_probs_uint16, kld, kld_ptr, p_diff_ptr); + p_diff_ptr += n_ctx - 1 - first; + kld_ptr += n_ctx - 1 - first; - auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); - const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); - const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; - const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); - LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); + LOG("%4d", i + seq + 1); - auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); - LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); + auto log_ppl = mean_and_uncertainty(kld.sum_nll, kld.sum_nll2, kld.count); + const double ppl_val = exp(log_ppl.first); + const double ppl_unc = ppl_val * log_ppl.second; + LOG(" %9.4lf ± %9.4lf", ppl_val, ppl_unc); - auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); - const double p_diff_rms_val = sqrt(p_diff_mse.first); - const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; - LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + auto log_ppl_base = mean_and_uncertainty(kld.sum_nll_base, kld.sum_nll_base2, kld.count); + const double log_ppl_cov = covariance(kld.sum_nll, kld.sum_nll_base, kld.sum_nll_nll_base, kld.count); + const double log_ppl_ratio_val = log_ppl.first - log_ppl_base.first; + const double log_ppl_ratio_unc = sqrt(log_ppl.second*log_ppl.second + log_ppl_base.second*log_ppl_base.second - 2.0*log_ppl_cov); + LOG(" %10.5lf ± %10.5lf", log_ppl_ratio_val, log_ppl_ratio_unc); - double p_top_val = 1.*kld.n_same_top/kld.count; - double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); - LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + auto kl_div = mean_and_uncertainty(kld.sum_kld, kld.sum_kld2, kld.count); + LOG(" %10.5lf ± %10.5lf", kl_div.first, kl_div.second); - LOG("\n"); + auto p_diff_mse = mean_and_uncertainty(kld.sum_p_diff2, kld.sum_p_diff4, kld.count); + const double p_diff_rms_val = sqrt(p_diff_mse.first); + const double p_diff_rms_unc = 0.5/p_diff_rms_val * p_diff_mse.second; + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_diff_rms_val, 100.0*p_diff_rms_unc); + + double p_top_val = 1.*kld.n_same_top/kld.count; + double p_top_unc = sqrt(p_top_val*(1 - p_top_val)/(kld.count - 1)); + LOG(" %6.3lf ± %6.3lf %%", 100.0*p_top_val, 100.0*p_top_unc); + + LOG("\n"); + } logits.clear(); } + + llama_batch_free(batch); LOG("\n"); if (kld.count < 100) return; // we do not wish to do statistics on so few values @@ -1996,7 +2024,7 @@ int main(int argc, char ** argv) { const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; - if (ppl) { + if (ppl || params.kl_divergence) { const int32_t n_seq = std::max(1, params.n_batch / n_ctx); const int32_t n_kv = n_seq * n_ctx; @@ -2006,12 +2034,8 @@ int main(int argc, char ** argv) { params.n_batch = std::min(params.n_batch, n_kv); } else { params.n_batch = std::min(params.n_batch, params.n_ctx); - if (params.kl_divergence) { - params.n_parallel = 1; - } else { - // ensure there's at least enough seq_ids for HellaSwag - params.n_parallel = std::max(4, params.n_parallel); - } + // ensure there's at least enough seq_ids for HellaSwag + params.n_parallel = std::max(4, params.n_parallel); } if (params.ppl_stride > 0) { diff --git a/tools/quantize/quantize.cpp b/tools/quantize/quantize.cpp index 0709e0bda0..c0f49279ee 100644 --- a/tools/quantize/quantize.cpp +++ b/tools/quantize/quantize.cpp @@ -119,27 +119,48 @@ static bool try_parse_ftype(const std::string & ftype_str_in, llama_ftype & ftyp [[noreturn]] static void usage(const char * executable) { printf("usage: %s [--help] [--allow-requantize] [--leave-output-tensor] [--pure] [--imatrix] [--include-weights]\n", executable); - printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file] [--prune-layers] [--keep-split] [--override-kv]\n"); + printf(" [--exclude-weights] [--output-tensor-type] [--token-embedding-type] [--tensor-type] [--tensor-type-file]\n"); + printf(" [--prune-layers] [--keep-split] [--override-kv]\n"); printf(" model-f32.gguf [model-quant.gguf] type [nthreads]\n\n"); - printf(" --allow-requantize: Allows requantizing tensors that have already been quantized. Warning: This can severely reduce quality compared to quantizing from 16bit or 32bit\n"); - printf(" --leave-output-tensor: Will leave output.weight un(re)quantized. Increases model size but may also increase quality, especially when requantizing\n"); - printf(" --pure: Disable k-quant mixtures and quantize all tensors to the same type\n"); - printf(" --imatrix file_name: use data in file_name as importance matrix for quant optimizations\n"); - printf(" --include-weights tensor_name: use importance matrix for this/these tensor(s)\n"); - printf(" --exclude-weights tensor_name: use importance matrix for this/these tensor(s)\n"); - printf(" --output-tensor-type ggml_type: use this ggml_type for the output.weight tensor\n"); - printf(" --token-embedding-type ggml_type: use this ggml_type for the token embeddings tensor\n"); - printf(" --tensor-type TENSOR=TYPE: quantize this tensor to this ggml_type. example: --tensor-type attn_q=q8_0\n"); - printf(" Advanced option to selectively quantize tensors. May be specified multiple times.\n"); - printf(" --tensor-type-file tensor_type.txt: list of tensors to quantize to specific ggml_type. example: --tensor-type-file tensor_type_list.txt\n"); - printf(" Advanced option to selectively quantize a long list of tensors. Format to be tensor_name=ggml_type, separated by spaces/newline.\n"); - printf(" --prune-layers L0,L1,L2...comma-separated list of layer numbers to prune from the model\n"); - printf(" Advanced option to remove all tensors from the given layers\n"); - printf(" --keep-split: will generate quantized model in the same shards as input\n"); + printf(" --allow-requantize\n"); + printf(" allow requantizing tensors that have already been quantized\n"); + printf(" WARNING: this can severely reduce quality compared to quantizing\n"); + printf(" from 16bit or 32bit!\n"); + printf(" --leave-output-tensor\n"); + printf(" leave output.weight un(re)quantized\n"); + printf(" increases model size but may also increase quality, especially when requantizing\n"); + printf(" --pure\n"); + printf(" disable k-quant mixtures and quantize all tensors to the same type\n"); + printf(" --imatrix file_name\n"); + printf(" use data in file_name as importance matrix for quant optimizations\n"); + printf(" --include-weights tensor_name\n"); + printf(" use importance matrix for this/these tensor(s)\n"); + printf(" --exclude-weights tensor_name\n"); + printf(" do not use importance matrix for this/these tensor(s)\n"); + printf(" --output-tensor-type ggml_type\n"); + printf(" use this ggml_type for the output.weight tensor\n"); + printf(" --token-embedding-type ggml_type\n"); + printf(" use this ggml_type for the token embeddings tensor\n"); + printf(" --tensor-type tensor_name=ggml_type\n"); + printf(" quantize this tensor to this ggml_type\n"); + printf(" this is an advanced option to selectively quantize tensors. may be specified multiple times.\n"); + printf(" example: --tensor-type attn_q=q8_0\n"); + printf(" --tensor-type-file tensor_types.txt\n"); + printf(" list of tensors to quantize to a specific ggml_type\n"); + printf(" this is an advanced option to selectively quantize a long list of tensors.\n"); + printf(" the file should use the same format as above, separated by spaces or newlines.\n"); + printf(" --prune-layers L0,L1,L2...\n"); + printf(" comma-separated list of layer numbers to prune from the model\n"); + printf(" WARNING: this is an advanced option, use with care.\n"); + printf(" --keep-split\n"); + printf(" generate quantized model in the same shards as input\n"); printf(" --override-kv KEY=TYPE:VALUE\n"); - printf(" Advanced option to override model metadata by key in the quantized model. May be specified multiple times.\n"); - printf("Note: --include-weights and --exclude-weights cannot be used together\n"); - printf("\nAllowed quantization types:\n"); + printf(" override model metadata by key in the quantized model. may be specified multiple times.\n"); + printf(" WARNING: this is an advanced option, use with care.\n\n"); + printf("note: --include-weights and --exclude-weights cannot be used together\n\n"); + printf("-----------------------------------------------------------------------------\n"); + printf(" allowed quantization types\n"); + printf("-----------------------------------------------------------------------------\n\n"); for (const auto & it : QUANT_OPTIONS) { if (it.name != "COPY") { printf(" %2d or ", it.ftype); diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 58b93c7468..6feb0e91f3 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -1,12 +1,7 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml-rpc.h" #ifdef _WIN32 # define NOMINMAX # define DIRECTORY_SEPARATOR '\\' -# include # include # include # include @@ -15,23 +10,43 @@ # include # include #endif -#include #include #include #include -#include #include #include #include -namespace fs = std::filesystem; +#if defined(__linux__) +#include +#include +#endif + +// NOTE: this is copied from common.cpp to avoid linking with libcommon +#ifdef _WIN32 +static std::wstring utf8_to_wstring(const std::string & str) { + if (str.empty()) { + return std::wstring(); + } + + int size = MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), NULL, 0); + + if (size <= 0) { + return std::wstring(); + } + + std::wstring wstr(size, 0); + MultiByteToWideChar(CP_UTF8, 0, str.c_str(), (int)str.size(), &wstr[0], size); + + return wstr; +} +#endif // NOTE: this is copied from common.cpp to avoid linking with libcommon // returns true if successful, false otherwise static bool fs_create_directory_with_parents(const std::string & path) { #ifdef _WIN32 - std::wstring_convert> converter; - std::wstring wpath = converter.from_bytes(path); + std::wstring wpath = utf8_to_wstring(path); // if the path already exists, check whether it's a directory const DWORD attributes = GetFileAttributesW(wpath.c_str()); @@ -44,9 +59,16 @@ static bool fs_create_directory_with_parents(const std::string & path) { // process path from front to back, procedurally creating directories while ((pos_slash = path.find('\\', pos_slash)) != std::string::npos) { const std::wstring subpath = wpath.substr(0, pos_slash); - const wchar_t * test = subpath.c_str(); - const bool success = CreateDirectoryW(test, NULL); + pos_slash += 1; + + // skip the drive letter, in some systems it can return an access denied error + if (subpath.length() == 2 && subpath[1] == ':') { + continue; + } + + const bool success = CreateDirectoryW(subpath.c_str(), NULL); + if (!success) { const DWORD error = GetLastError(); @@ -60,8 +82,6 @@ static bool fs_create_directory_with_parents(const std::string & path) { return false; } } - - pos_slash += 1; } return true; @@ -112,16 +132,31 @@ static std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); - } else { + } else if (std::getenv("HOME")) { cache_directory = std::getenv("HOME") + std::string("/.cache/"); + } else { +#if defined(__linux__) + /* no $HOME is defined, fallback to getpwuid */ + struct passwd *pw = getpwuid(getuid()); + if ((!pw) || (!pw->pw_dir)) { + throw std::runtime_error("Failed to find $HOME directory"); + } + + cache_directory = std::string(pw->pw_dir) + std::string("/.cache/"); +#else /* defined(__linux__) */ + throw std::runtime_error("Failed to find $HOME directory"); +#endif /* defined(__linux__) */ } #elif defined(__APPLE__) cache_directory = std::getenv("HOME") + std::string("/Library/Caches/"); #elif defined(_WIN32) cache_directory = std::getenv("LOCALAPPDATA"); +#elif defined(__EMSCRIPTEN__) + GGML_ABORT("not implemented on this platform"); #else # error Unknown architecture #endif diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index a39b4c5b35..5621a51b22 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -28,10 +28,6 @@ target_link_libraries(${TARGET} PUBLIC common mtmd ${CMAKE_THREAD_LIBS_INIT}) set(TARGET llama-server) -if (NOT LLAMA_HTTPLIB) - message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF") -endif() - set(TARGET_SRCS server.cpp server-http.cpp @@ -63,8 +59,4 @@ target_include_directories(${TARGET} PRIVATE ../mtmd) target_include_directories(${TARGET} PRIVATE ${CMAKE_SOURCE_DIR}) target_link_libraries(${TARGET} PRIVATE server-context PUBLIC common cpp-httplib ${CMAKE_THREAD_LIBS_INIT}) -if (WIN32) - TARGET_LINK_LIBRARIES(${TARGET} PRIVATE ws2_32) -endif() - target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/server/README.md b/tools/server/README.md index d132830171..0b56ca1e27 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp. * Speculative decoding * Easy-to-use web UI -For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) +For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) ## Usage diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index e3b06f4901..cec38413a5 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index a853f65c8d..d717fb6698 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -916,8 +916,7 @@ json oaicompat_chat_params_parse( json image_url = json_value(p, "image_url", json::object()); handle_media(out_files, image_url, opt.media_path); - // replace this chunk with a marker - p["type"] = "text"; + p["type"] = "media_marker"; p["text"] = mtmd_default_marker(); p.erase("image_url"); @@ -938,8 +937,7 @@ json oaicompat_chat_params_parse( // TODO: add audio_url support by reusing handle_media() - // replace this chunk with a marker - p["type"] = "text"; + p["type"] = "media_marker"; p["text"] = mtmd_default_marker(); p.erase("input_audio"); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index b71d496eeb..8aab0d4c1b 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -77,6 +77,7 @@ struct server_slot { size_t last_nl_pos = 0; std::string generated_text; + std::string debug_generated_text; llama_tokens generated_tokens; // idx of draft tokens in the main batch @@ -425,7 +426,7 @@ struct server_slot { if (!only_metrics) { res["prompt"] = ptask->tokens.detokenize(ctx, true); - res["generated"] = generated_text; + res["generated"] = generated_text.empty() ? debug_generated_text : generated_text; } } @@ -1442,6 +1443,12 @@ private: res->id_slot = slot.id; res->index = slot.task->index; + + // keep copy of last generated text for debugging purposes + if (slots_debug) { + slot.debug_generated_text = slot.generated_text; + } + // in stream mode, content and tokens are already in last partial chunk if (slot.task->params.stream) { res->content = ""; @@ -2507,7 +2514,8 @@ private: slot.n_prompt_tokens_processed++; // process the last few tokens of the prompt separately in order to allow for a checkpoint to be created. - if (do_checkpoint && slot.task->n_tokens() - slot.prompt.n_tokens() == 64) { + const int n_last = std::min(n_batch, 512); + if (do_checkpoint && slot.task->n_tokens() == slot.prompt.n_tokens() + n_last) { break; } } @@ -3583,6 +3591,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_responses_to_chatcmpl(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: OpenAI Responses -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, @@ -3599,6 +3609,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, @@ -3615,6 +3627,8 @@ void server_routes::init_routes() { auto res = create_response(); std::vector files; json body = convert_anthropic_to_oai(json::parse(req.body)); + SRV_DBG("%s\n", "Request converted: Anthropic -> OpenAI Chat Completions"); + SRV_DBG("converted request: %s\n", body.dump().c_str()); json body_parsed = oaicompat_chat_params_parse( body, meta->chat_params, diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 2d25db63b7..a137427c69 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -80,7 +80,6 @@ json task_params::to_json(bool only_metrics) const { {"speculative.type", common_speculative_type_to_str(speculative.type)}, {"speculative.ngram_size_n", speculative.ngram_size_n}, {"speculative.ngram_size_m", speculative.ngram_size_m}, - {"speculative.ngram_c_rate", speculative.ngram_check_rate}, {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, @@ -144,7 +143,6 @@ json task_params::to_json(bool only_metrics) const { {"speculative.type", common_speculative_type_to_str(speculative.type)}, {"speculative.ngram_size_n", speculative.ngram_size_n}, {"speculative.ngram_size_m", speculative.ngram_size_m}, - {"speculative.ngram_c_rate", speculative.ngram_check_rate}, {"speculative.ngram_m_hits", speculative.ngram_min_hits}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, @@ -257,12 +255,10 @@ task_params server_task::params_from_json_cmpl( params.speculative.ngram_size_n = json_value(data, "speculative.ngram_size_n", defaults.speculative.ngram_size_n); params.speculative.ngram_size_m = json_value(data, "speculative.ngram_size_m", defaults.speculative.ngram_size_m); - params.speculative.ngram_check_rate = json_value(data, "speculative.ngram_c_rate", defaults.speculative.ngram_check_rate); params.speculative.ngram_min_hits = json_value(data, "speculative.ngram_m_hits", defaults.speculative.ngram_min_hits); params.speculative.ngram_size_n = std::max(std::min(1, (int) params.speculative.ngram_size_n), 1024); params.speculative.ngram_size_m = std::max(std::min(1, (int) params.speculative.ngram_size_m), 1024); - params.speculative.ngram_check_rate = std::max(std::min(1, (int) params.speculative.ngram_check_rate), 1024); params.speculative.ngram_min_hits = std::max(std::min(1, (int) params.speculative.ngram_min_hits), 1024); // Use OpenAI API logprobs only if n_probs wasn't provided diff --git a/tools/server/webui/.storybook/main.ts b/tools/server/webui/.storybook/main.ts index bfd16fa224..4f6945f210 100644 --- a/tools/server/webui/.storybook/main.ts +++ b/tools/server/webui/.storybook/main.ts @@ -1,17 +1,24 @@ import type { StorybookConfig } from '@storybook/sveltekit'; +import { dirname, resolve } from 'path'; +import { fileURLToPath } from 'url'; + +const __dirname = dirname(fileURLToPath(import.meta.url)); const config: StorybookConfig = { stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'], addons: [ '@storybook/addon-svelte-csf', '@chromatic-com/storybook', - '@storybook/addon-docs', + '@storybook/addon-vitest', '@storybook/addon-a11y', - '@storybook/addon-vitest' + '@storybook/addon-docs' ], - framework: { - name: '@storybook/sveltekit', - options: {} + framework: '@storybook/sveltekit', + viteFinal: async (config) => { + config.server = config.server || {}; + config.server.fs = config.server.fs || {}; + config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')]; + return config; } }; export default config; diff --git a/tools/server/webui/.storybook/preview.ts b/tools/server/webui/.storybook/preview.ts index 8d530e43e3..566dbfd289 100644 --- a/tools/server/webui/.storybook/preview.ts +++ b/tools/server/webui/.storybook/preview.ts @@ -13,7 +13,7 @@ const preview: Preview = { }, backgrounds: { - disable: true + disabled: true }, a11y: { diff --git a/tools/server/webui/docs/flows/settings-flow.md b/tools/server/webui/docs/flows/settings-flow.md index 578e01e6e1..40ad3bd94d 100644 --- a/tools/server/webui/docs/flows/settings-flow.md +++ b/tools/server/webui/docs/flows/settings-flow.md @@ -49,14 +49,20 @@ sequenceDiagram settingsStore->>serverStore: defaultParams serverStore-->>settingsStore: {temperature, top_p, top_k, ...} - settingsStore->>ParamSvc: extractServerDefaults(defaultParams) - ParamSvc-->>settingsStore: Record + loop each SYNCABLE_PARAMETER + alt key NOT in userOverrides + settingsStore->>settingsStore: config[key] = serverDefault[key] + Note right of settingsStore: Non-overridden params adopt server default + else key in userOverrides + Note right of settingsStore: Keep user value, skip server default + end + end - settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults) - Note right of ParamSvc: For each syncable parameter:
- If NOT in userOverrides → use server default
- If in userOverrides → keep user value - ParamSvc-->>settingsStore: mergedConfig + alt serverStore.props has webuiSettings + settingsStore->>settingsStore: Apply webuiSettings from server + Note right of settingsStore: Server-provided UI settings
(e.g. showRawOutputSwitch) + end - settingsStore->>settingsStore: config = mergedConfig settingsStore->>settingsStore: saveConfig() deactivate settingsStore @@ -67,11 +73,18 @@ sequenceDiagram UI->>settingsStore: updateConfig(key, value) activate settingsStore settingsStore->>settingsStore: config[key] = value - settingsStore->>settingsStore: userOverrides.add(key) - Note right of settingsStore: Mark as user-modified (won't be overwritten by server) + + alt value matches server default for key + settingsStore->>settingsStore: userOverrides.delete(key) + Note right of settingsStore: Matches server default, remove override + else value differs from server default + settingsStore->>settingsStore: userOverrides.add(key) + Note right of settingsStore: Mark as user-modified (won't be overwritten) + end + settingsStore->>settingsStore: saveConfig() - settingsStore->>LS: set("llama-config", config) - settingsStore->>LS: set("llama-userOverrides", [...userOverrides]) + settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config) + settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides]) deactivate settingsStore UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2}) @@ -88,10 +101,9 @@ sequenceDiagram UI->>settingsStore: resetConfig() activate settingsStore - settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT} settingsStore->>settingsStore: userOverrides.clear() - settingsStore->>settingsStore: syncWithServerDefaults() - Note right of settingsStore: Apply server defaults for syncable params + Note right of settingsStore: All params reset to defaults
Next syncWithServerDefaults will adopt server values settingsStore->>settingsStore: saveConfig() deactivate settingsStore @@ -139,6 +151,6 @@ sequenceDiagram Note over settingsStore: UI-only (not synced): rect rgb(255, 240, 240) - Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningFormat + Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch end ``` diff --git a/tools/server/webui/eslint.config.js b/tools/server/webui/eslint.config.js index 5baea57f33..cd20fb383a 100644 --- a/tools/server/webui/eslint.config.js +++ b/tools/server/webui/eslint.config.js @@ -27,7 +27,9 @@ export default ts.config( // typescript-eslint strongly recommend that you do not use the no-undef lint rule on TypeScript projects. // see: https://typescript-eslint.io/troubleshooting/faqs/eslint/#i-get-errors-from-the-no-undef-rule-about-global-variables-not-being-defined-even-though-there-are-no-typescript-errors 'no-undef': 'off', - 'svelte/no-at-html-tags': 'off' + 'svelte/no-at-html-tags': 'off', + // This app uses hash-based routing (#/) where resolve() from $app/paths does not apply + 'svelte/no-navigation-without-resolve': 'off' } }, { diff --git a/tools/server/webui/package-lock.json b/tools/server/webui/package-lock.json index 6834416824..8d13e5a535 100644 --- a/tools/server/webui/package-lock.json +++ b/tools/server/webui/package-lock.json @@ -22,31 +22,32 @@ "unist-util-visit": "^5.0.0" }, "devDependencies": { - "@chromatic-com/storybook": "^4.1.2", + "@chromatic-com/storybook": "^5.0.0", "@eslint/compat": "^1.2.5", "@eslint/js": "^9.18.0", "@internationalized/date": "^3.10.1", "@lucide/svelte": "^0.515.0", "@playwright/test": "^1.49.1", - "@storybook/addon-a11y": "^10.0.7", - "@storybook/addon-docs": "^10.0.7", + "@storybook/addon-a11y": "^10.2.4", + "@storybook/addon-docs": "^10.2.4", "@storybook/addon-svelte-csf": "^5.0.10", - "@storybook/addon-vitest": "^10.0.7", - "@storybook/sveltekit": "^10.0.7", + "@storybook/addon-vitest": "^10.2.4", + "@storybook/sveltekit": "^10.2.4", "@sveltejs/adapter-static": "^3.0.10", "@sveltejs/kit": "^2.48.4", "@sveltejs/vite-plugin-svelte": "^6.2.1", "@tailwindcss/forms": "^0.5.9", "@tailwindcss/typography": "^0.5.15", "@tailwindcss/vite": "^4.0.0", - "@types/node": "^22", + "@types/node": "^24", "@vitest/browser": "^3.2.3", + "@vitest/coverage-v8": "^3.2.3", "bits-ui": "^2.14.4", "clsx": "^2.1.1", "dexie": "^4.0.11", "eslint": "^9.18.0", "eslint-config-prettier": "^10.0.1", - "eslint-plugin-storybook": "^10.0.7", + "eslint-plugin-storybook": "^10.2.4", "eslint-plugin-svelte": "^3.0.0", "fflate": "^0.8.2", "globals": "^16.0.0", @@ -60,7 +61,7 @@ "rehype-katex": "^7.0.1", "remark-math": "^6.0.0", "sass": "^1.93.3", - "storybook": "^10.0.7", + "storybook": "^10.2.4", "svelte": "^5.38.2", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", @@ -113,16 +114,42 @@ "node": ">=6.9.0" } }, - "node_modules/@babel/helper-validator-identifier": { + "node_modules/@babel/helper-string-parser": { "version": "7.27.1", - "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.27.1.tgz", - "integrity": "sha512-D2hP9eA+Sqx1kBZgzxZh0y1trbuU+JoDkiEwqhQ36nodYqJwyEIhPSdMNd7lOm/4io72luTPWH20Yda0xOuUow==", + "resolved": "https://registry.npmjs.org/@babel/helper-string-parser/-/helper-string-parser-7.27.1.tgz", + "integrity": "sha512-qMlSxKbpRlAridDExk92nSobyDdpPijUq2DW6oDnUqd0iOGxmQjyqhMIihI9+zv4LPyZdRje2cavWPbCbWm3eA==", "dev": true, "license": "MIT", "engines": { "node": ">=6.9.0" } }, + "node_modules/@babel/helper-validator-identifier": { + "version": "7.28.5", + "resolved": "https://registry.npmjs.org/@babel/helper-validator-identifier/-/helper-validator-identifier-7.28.5.tgz", + "integrity": "sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@babel/parser": { + "version": "7.29.0", + "resolved": "https://registry.npmjs.org/@babel/parser/-/parser-7.29.0.tgz", + "integrity": "sha512-IyDgFV5GeDUVX4YdF/3CPULtVGSXXMLh1xVIgdCgxApktqnQV0r7/8Nqthg+8YLGaAtdyIlo2qIdZrbCv4+7ww==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/types": "^7.29.0" + }, + "bin": { + "parser": "bin/babel-parser.js" + }, + "engines": { + "node": ">=6.0.0" + } + }, "node_modules/@babel/runtime": { "version": "7.27.6", "resolved": "https://registry.npmjs.org/@babel/runtime/-/runtime-7.27.6.tgz", @@ -133,15 +160,39 @@ "node": ">=6.9.0" } }, + "node_modules/@babel/types": { + "version": "7.29.0", + "resolved": "https://registry.npmjs.org/@babel/types/-/types-7.29.0.tgz", + "integrity": "sha512-LwdZHpScM4Qz8Xw2iKSzS+cfglZzJGvofQICy7W7v4caru4EaAmyUuO6BGrbyQ2mYV11W0U8j5mBhd14dd3B0A==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/helper-string-parser": "^7.27.1", + "@babel/helper-validator-identifier": "^7.28.5" + }, + "engines": { + "node": ">=6.9.0" + } + }, + "node_modules/@bcoe/v8-coverage": { + "version": "1.0.2", + "resolved": "https://registry.npmjs.org/@bcoe/v8-coverage/-/v8-coverage-1.0.2.tgz", + "integrity": "sha512-6zABk/ECA/QYSCQ1NGiVwwbQerUCZ+TQbp64Q3AgmfNvurHH0j8TtXa1qbShXA6qqkpAj4V5W8pP6mLe1mcMqA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + } + }, "node_modules/@chromatic-com/storybook": { - "version": "4.1.2", - "resolved": "https://registry.npmjs.org/@chromatic-com/storybook/-/storybook-4.1.2.tgz", - "integrity": "sha512-QAWGtHwib0qsP5CcO64aJCF75zpFgpKK3jNpxILzQiPK3sVo4EmnVGJVdwcZWpWrGdH8E4YkncGoitw4EXzKMg==", + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/@chromatic-com/storybook/-/storybook-5.0.0.tgz", + "integrity": "sha512-8wUsqL8kg6R5ue8XNE7Jv/iD1SuE4+6EXMIGIuE+T2loBITEACLfC3V8W44NJviCLusZRMWbzICddz0nU0bFaw==", "dev": true, "license": "MIT", "dependencies": { "@neoconfetti/react": "^1.0.0", - "chromatic": "^12.0.0", + "chromatic": "^13.3.4", "filesize": "^10.0.12", "jsonfile": "^6.1.0", "strip-ansi": "^7.1.0" @@ -151,7 +202,7 @@ "yarn": ">=1.22.18" }, "peerDependencies": { - "storybook": "^0.0.0-0 || ^9.0.0 || ^9.1.0-0 || ^9.2.0-0 || ^10.0.0-0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0" + "storybook": "^0.0.0-0 || ^10.1.0 || ^10.1.0-0 || ^10.2.0-0 || ^10.3.0-0" } }, "node_modules/@esbuild/aix-ppc64": { @@ -597,9 +648,9 @@ } }, "node_modules/@eslint-community/eslint-utils": { - "version": "4.7.0", - "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.7.0.tgz", - "integrity": "sha512-dyybb3AcajC7uha6CvhdVRJqaKyn7w2YKqKyAN37NKYgZT36w+iRb0Dymmc5qEJ549c/S31cMMSFd75bteCpCw==", + "version": "4.9.1", + "resolved": "https://registry.npmjs.org/@eslint-community/eslint-utils/-/eslint-utils-4.9.1.tgz", + "integrity": "sha512-phrYmNiYppR7znFEdqgfWHXR6NCkZEK7hwWDHZUjit/2/U0r6XvkDl0SYnoM51Hq7FhCGdLDT6zxCCOY1hexsQ==", "dev": true, "license": "MIT", "dependencies": { @@ -629,9 +680,9 @@ } }, "node_modules/@eslint-community/regexpp": { - "version": "4.12.1", - "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.1.tgz", - "integrity": "sha512-CCZCDJuduB9OUkFkY2IgppNZMi2lBQgD2qzwXkEia16cge2pijY/aXi96CJMquDMn3nJdlPV1A5KrJEXwfLNzQ==", + "version": "4.12.2", + "resolved": "https://registry.npmjs.org/@eslint-community/regexpp/-/regexpp-4.12.2.tgz", + "integrity": "sha512-EriSTlt5OC9/7SXkRSCAhfSxxoSUgBm33OH+IkwbdpgoqsSsUg7y3uh+IICI/Qg4BBWr3U2i39RpmycbxMq4ew==", "dev": true, "license": "MIT", "engines": { @@ -639,11 +690,14 @@ } }, "node_modules/@eslint/compat": { - "version": "1.3.1", - "resolved": "https://registry.npmjs.org/@eslint/compat/-/compat-1.3.1.tgz", - "integrity": "sha512-k8MHony59I5EPic6EQTCNOuPoVBnoYXkP+20xvwFjN7t0qI3ImyvyBgg+hIVPwC8JaxVjjUZld+cLfBLFDLucg==", + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/@eslint/compat/-/compat-1.4.1.tgz", + "integrity": "sha512-cfO82V9zxxGBxcQDr1lfaYB7wykTa0b00mGa36FrJl7iTFd0Z2cHfEYuxcBRP/iNijCsWsEkA+jzT8hGYmv33w==", "dev": true, "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0" + }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" }, @@ -657,13 +711,13 @@ } }, "node_modules/@eslint/config-array": { - "version": "0.21.0", - "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.0.tgz", - "integrity": "sha512-ENIdc4iLu0d93HeYirvKmrzshzofPw6VkZRKQGe9Nv46ZnWUzcF1xV01dcvEg/1wXUR61OmmlSfyeyO7EvjLxQ==", + "version": "0.21.1", + "resolved": "https://registry.npmjs.org/@eslint/config-array/-/config-array-0.21.1.tgz", + "integrity": "sha512-aw1gNayWpdI/jSYVgzN5pL0cfzU02GT3NBpeT/DXbx1/1x7ZKxFPd9bwrzygx/qiwIQiJ1sw/zD8qY/kRvlGHA==", "dev": true, "license": "Apache-2.0", "dependencies": { - "@eslint/object-schema": "^2.1.6", + "@eslint/object-schema": "^2.1.7", "debug": "^4.3.1", "minimatch": "^3.1.2" }, @@ -672,19 +726,22 @@ } }, "node_modules/@eslint/config-helpers": { - "version": "0.3.0", - "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.3.0.tgz", - "integrity": "sha512-ViuymvFmcJi04qdZeDc2whTHryouGcDlaxPqarTD0ZE10ISpxGUVZGZDx4w01upyIynL3iu6IXH2bS1NhclQMw==", + "version": "0.4.2", + "resolved": "https://registry.npmjs.org/@eslint/config-helpers/-/config-helpers-0.4.2.tgz", + "integrity": "sha512-gBrxN88gOIf3R7ja5K9slwNayVcZgK6SOUORm2uBzTeIEfeVaIhOpCtTox3P6R7o2jLFwLFTLnC7kU/RGcYEgw==", "dev": true, "license": "Apache-2.0", + "dependencies": { + "@eslint/core": "^0.17.0" + }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" } }, "node_modules/@eslint/core": { - "version": "0.15.2", - "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.15.2.tgz", - "integrity": "sha512-78Md3/Rrxh83gCxoUc0EiciuOHsIITzLy53m3d9UyiW8y9Dj2D29FeETqyKA+BRK76tnTp6RXWb3pCay8Oyomg==", + "version": "0.17.0", + "resolved": "https://registry.npmjs.org/@eslint/core/-/core-0.17.0.tgz", + "integrity": "sha512-yL/sLrpmtDaFEiUj1osRP4TI2MDz1AddJL+jZ7KSqvBuliN4xqYY54IfdN8qD8Toa6g1iloph1fxQNkjOxrrpQ==", "dev": true, "license": "Apache-2.0", "dependencies": { @@ -695,9 +752,9 @@ } }, "node_modules/@eslint/eslintrc": { - "version": "3.3.1", - "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.1.tgz", - "integrity": "sha512-gtF186CXhIl1p4pJNGZw8Yc6RlshoePRvE0X91oPGb3vZ8pM3qOS9W9NGPat9LziaBV7XrJWGylNQXkGcnM3IQ==", + "version": "3.3.3", + "resolved": "https://registry.npmjs.org/@eslint/eslintrc/-/eslintrc-3.3.3.tgz", + "integrity": "sha512-Kr+LPIUVKz2qkx1HAMH8q1q6azbqBAsXJUxBl/ODDuVPX45Z9DfwB8tPjTi6nNZ8BuM3nbJxC5zCAg5elnBUTQ==", "dev": true, "license": "MIT", "dependencies": { @@ -707,7 +764,7 @@ "globals": "^14.0.0", "ignore": "^5.2.0", "import-fresh": "^3.2.1", - "js-yaml": "^4.1.0", + "js-yaml": "^4.1.1", "minimatch": "^3.1.2", "strip-json-comments": "^3.1.1" }, @@ -732,9 +789,9 @@ } }, "node_modules/@eslint/js": { - "version": "9.31.0", - "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.31.0.tgz", - "integrity": "sha512-LOm5OVt7D4qiKCqoiPbA7LWmI+tbw1VbTUowBcUMgQSuM6poJufkFkYDcQpo5KfgD39TnNySV26QjOh7VFpSyw==", + "version": "9.39.2", + "resolved": "https://registry.npmjs.org/@eslint/js/-/js-9.39.2.tgz", + "integrity": "sha512-q1mjIoW1VX4IvSocvM/vbTiveKC4k9eLrajNEuSsmjymSDEbpGddtpfOoN7YGAqBK3NG+uqo8ia4PDTt8buCYA==", "dev": true, "license": "MIT", "engines": { @@ -745,9 +802,9 @@ } }, "node_modules/@eslint/object-schema": { - "version": "2.1.6", - "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.6.tgz", - "integrity": "sha512-RBMg5FRL0I0gs51M/guSAj5/e14VQ4tpZnQNWwuDT66P14I43ItmPfIZRhO9fUVIPOAQXU47atlywZ/czoqFPA==", + "version": "2.1.7", + "resolved": "https://registry.npmjs.org/@eslint/object-schema/-/object-schema-2.1.7.tgz", + "integrity": "sha512-VtAOaymWVfZcmZbp6E2mympDIHvyjXs/12LqWYjVw6qjrfF+VK+fyG33kChz3nnK+SU5/NeHOqrTEHS8sXO3OA==", "dev": true, "license": "Apache-2.0", "engines": { @@ -755,13 +812,13 @@ } }, "node_modules/@eslint/plugin-kit": { - "version": "0.3.5", - "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.3.5.tgz", - "integrity": "sha512-Z5kJ+wU3oA7MMIqVR9tyZRtjYPr4OC004Q4Rw7pgOKUOKkJfZ3O24nz3WYfGRpMDNmcOi3TwQOmgm7B7Tpii0w==", + "version": "0.4.1", + "resolved": "https://registry.npmjs.org/@eslint/plugin-kit/-/plugin-kit-0.4.1.tgz", + "integrity": "sha512-43/qtrDUokr7LJqoF2c3+RInu/t4zfrpYdoSDfYyhg52rwLV6TnOvdG4fXm7IkSB3wErkcmJS9iEhjVtOSEjjA==", "dev": true, "license": "Apache-2.0", "dependencies": { - "@eslint/core": "^0.15.2", + "@eslint/core": "^0.17.0", "levn": "^0.4.1" }, "engines": { @@ -873,6 +930,24 @@ "@swc/helpers": "^0.5.0" } }, + "node_modules/@isaacs/cliui": { + "version": "8.0.2", + "resolved": "https://registry.npmjs.org/@isaacs/cliui/-/cliui-8.0.2.tgz", + "integrity": "sha512-O8jcjabXaleOG9DQ0+ARXWZBTfnP4WNAqzuiJK7ll44AmxGKv/J2M4TPjxjY3znBCfvBXFzucm1twdyFybFqEA==", + "dev": true, + "license": "ISC", + "dependencies": { + "string-width": "^5.1.2", + "string-width-cjs": "npm:string-width@^4.2.0", + "strip-ansi": "^7.0.1", + "strip-ansi-cjs": "npm:strip-ansi@^6.0.1", + "wrap-ansi": "^8.1.0", + "wrap-ansi-cjs": "npm:wrap-ansi@^7.0.0" + }, + "engines": { + "node": ">=12" + } + }, "node_modules/@isaacs/fs-minipass": { "version": "4.0.1", "resolved": "https://registry.npmjs.org/@isaacs/fs-minipass/-/fs-minipass-4.0.1.tgz", @@ -886,6 +961,16 @@ "node": ">=18.0.0" } }, + "node_modules/@istanbuljs/schema": { + "version": "0.1.3", + "resolved": "https://registry.npmjs.org/@istanbuljs/schema/-/schema-0.1.3.tgz", + "integrity": "sha512-ZXRY4jNvVgSVQ8DL3LTcakaAtXwTVUxE81hslsyD2AtoXW/wVob10HkOJ1X/pAlcI7D+2YoZKg5do8G/w6RYgA==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/@jridgewell/gen-mapping": { "version": "0.3.12", "resolved": "https://registry.npmjs.org/@jridgewell/gen-mapping/-/gen-mapping-0.3.12.tgz", @@ -922,9 +1007,9 @@ "license": "MIT" }, "node_modules/@jridgewell/trace-mapping": { - "version": "0.3.29", - "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.29.tgz", - "integrity": "sha512-uw6guiW/gcAGPDhLmd77/6lW8QLeiV5RUTsAX46Db6oLhGaVj4lhnPwb184s1bkc8kdVg/+h988dro8GRDpmYQ==", + "version": "0.3.31", + "resolved": "https://registry.npmjs.org/@jridgewell/trace-mapping/-/trace-mapping-0.3.31.tgz", + "integrity": "sha512-zzNR+SdQSDJzc8joaeP8QQoCQr8NuYx2dIIytl1QeBEZHJ9uW6hebsrYgbz8hJwUQao3TWCMtmfV8Nu1twOLAw==", "license": "MIT", "dependencies": { "@jridgewell/resolve-uri": "^3.1.0", @@ -1151,44 +1236,6 @@ "dev": true, "license": "MIT" }, - "node_modules/@nodelib/fs.scandir": { - "version": "2.1.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.scandir/-/fs.scandir-2.1.5.tgz", - "integrity": "sha512-vq24Bq3ym5HEQm2NKCr3yXDwjc7vTsEThRDnkp2DK9p1uqLR+DHurm/NOTo0KG7HYHU7eppKZj3MyqYuMBf62g==", - "dev": true, - "license": "MIT", - "dependencies": { - "@nodelib/fs.stat": "2.0.5", - "run-parallel": "^1.1.9" - }, - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.stat": { - "version": "2.0.5", - "resolved": "https://registry.npmjs.org/@nodelib/fs.stat/-/fs.stat-2.0.5.tgz", - "integrity": "sha512-RkhPPp2zrqDAQA/2jNhnztcPAlv64XdhIp7a7454A5ovI7Bukxgt7MX7udwAu3zg1DcpPU0rz3VV1SeaqvY4+A==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 8" - } - }, - "node_modules/@nodelib/fs.walk": { - "version": "1.2.8", - "resolved": "https://registry.npmjs.org/@nodelib/fs.walk/-/fs.walk-1.2.8.tgz", - "integrity": "sha512-oGB+UxlgWcgQkgwo8GcEGwemoTFt3FIO9ababBmaGwXIoBKZ+GTy0pP185beGg7Llih/NSHSV2XAs1lnznocSg==", - "dev": true, - "license": "MIT", - "dependencies": { - "@nodelib/fs.scandir": "2.1.5", - "fastq": "^1.6.0" - }, - "engines": { - "node": ">= 8" - } - }, "node_modules/@parcel/watcher": { "version": "2.5.1", "resolved": "https://registry.npmjs.org/@parcel/watcher/-/watcher-2.5.1.tgz", @@ -1513,6 +1560,17 @@ "node": ">=0.10" } }, + "node_modules/@pkgjs/parseargs": { + "version": "0.11.0", + "resolved": "https://registry.npmjs.org/@pkgjs/parseargs/-/parseargs-0.11.0.tgz", + "integrity": "sha512-+1VkjdD0QBLPodGrJUeqarH8VAIvQODIbwh9XpP5Syisf7YoQgsJKPNFoqqLQlu+VQ/tVSshMR6loPMn8U+dPg==", + "dev": true, + "license": "MIT", + "optional": true, + "engines": { + "node": ">=14" + } + }, "node_modules/@playwright/test": { "version": "1.56.1", "resolved": "https://registry.npmjs.org/@playwright/test/-/test-1.56.1.tgz", @@ -1824,9 +1882,9 @@ "license": "MIT" }, "node_modules/@storybook/addon-a11y": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/addon-a11y/-/addon-a11y-10.0.7.tgz", - "integrity": "sha512-JsYPpZ/n67/2bI1XJeyrAWHHQkHemPkPHjCA0tAUnMz1Shlo/LV2q1Ahgpxoihx4strbHwZz71bcS4MqkHBduA==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/addon-a11y/-/addon-a11y-10.2.4.tgz", + "integrity": "sha512-VGhdZ+iP2l/CSulIKV2kt3SMWVHntOigqWqGkNYf6YNYofynUYEKdsNqBvHx4ySuNEl/eXJ8LRO8FKYnU7LxZQ==", "dev": true, "license": "MIT", "dependencies": { @@ -1838,20 +1896,20 @@ "url": "https://opencollective.com/storybook" }, "peerDependencies": { - "storybook": "^10.0.7" + "storybook": "^10.2.4" } }, "node_modules/@storybook/addon-docs": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/addon-docs/-/addon-docs-10.0.7.tgz", - "integrity": "sha512-qQQMoeYZC4W+/8ubfOZiTrE8nYC/f4wWP1uq4peRyDy1N2nIN9SwhyxwMn0m3VpeGmRBga5dLvJY9ko6SnJekg==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/addon-docs/-/addon-docs-10.2.4.tgz", + "integrity": "sha512-FzscAmdBiOGnGrxiEM+8eTg43kjqgjLfObg+lbJVRR/a0DmZ3xfAPNB0+VKYQbN0FacNcWLM9LZ/7U0hRBPBnQ==", "dev": true, "license": "MIT", "dependencies": { "@mdx-js/react": "^3.0.0", - "@storybook/csf-plugin": "10.0.7", - "@storybook/icons": "^1.6.0", - "@storybook/react-dom-shim": "10.0.7", + "@storybook/csf-plugin": "10.2.4", + "@storybook/icons": "^2.0.1", + "@storybook/react-dom-shim": "10.2.4", "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", "ts-dedent": "^2.0.0" @@ -1861,7 +1919,7 @@ "url": "https://opencollective.com/storybook" }, "peerDependencies": { - "storybook": "^10.0.7" + "storybook": "^10.2.4" } }, "node_modules/@storybook/addon-svelte-csf": { @@ -1888,16 +1946,14 @@ } }, "node_modules/@storybook/addon-vitest": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/addon-vitest/-/addon-vitest-10.0.7.tgz", - "integrity": "sha512-i6v/mAl+elrUxb+1f4NdnM17t/fg+KGJWL1U9quflXTd3KiLY0xJB4LwNP6yYo7Imc5NIO2fRkJbGvNqLBRe2Q==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/addon-vitest/-/addon-vitest-10.2.4.tgz", + "integrity": "sha512-BT1iP89U4wcbpzTURU8WYTAeUcdNh4WIt0BqsnATmMwR/jKNJW6QgXCVqGQTSpRjWj40hX5e2JkQYCNXdjKsPw==", "dev": true, "license": "MIT", "dependencies": { "@storybook/global": "^5.0.0", - "@storybook/icons": "^1.6.0", - "prompts": "^2.4.0", - "ts-dedent": "^2.2.0" + "@storybook/icons": "^2.0.1" }, "funding": { "type": "opencollective", @@ -1907,7 +1963,7 @@ "@vitest/browser": "^3.0.0 || ^4.0.0", "@vitest/browser-playwright": "^4.0.0", "@vitest/runner": "^3.0.0 || ^4.0.0", - "storybook": "^10.0.7", + "storybook": "^10.2.4", "vitest": "^3.0.0 || ^4.0.0" }, "peerDependenciesMeta": { @@ -1926,13 +1982,13 @@ } }, "node_modules/@storybook/builder-vite": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/builder-vite/-/builder-vite-10.0.7.tgz", - "integrity": "sha512-wk2TAoUY5+9t78GWVBndu9rEo9lo6Ec3SRrLT4VpIlcS2GPK+5f26UC2uvIBwOF/N7JrUUKq/zWDZ3m+do9QDg==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/builder-vite/-/builder-vite-10.2.4.tgz", + "integrity": "sha512-/hcT1xj3CL5GkJ5v5/EguZdttDwNE6weNXK7vKzp034tnGcLycOossDsTiUQkBowSL+Ylc8aKj+ZgvddPNfOig==", "dev": true, "license": "MIT", "dependencies": { - "@storybook/csf-plugin": "10.0.7", + "@storybook/csf-plugin": "10.2.4", "ts-dedent": "^2.0.0" }, "funding": { @@ -1940,7 +1996,7 @@ "url": "https://opencollective.com/storybook" }, "peerDependencies": { - "storybook": "^10.0.7", + "storybook": "^10.2.4", "vite": "^5.0.0 || ^6.0.0 || ^7.0.0" } }, @@ -1955,9 +2011,9 @@ } }, "node_modules/@storybook/csf-plugin": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/csf-plugin/-/csf-plugin-10.0.7.tgz", - "integrity": "sha512-YaYYlCyJBwxaMk7yREOdz+9MDSgxIYGdeJ9EIq/bUndmkoj9SRo1P9/0lC5dseWQoiGy4T3PbZiWruD8uM5m3g==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/csf-plugin/-/csf-plugin-10.2.4.tgz", + "integrity": "sha512-kupPQEV+4N9mzsZHYaokvhO/KHBjYdWda9PNmPQwy0TR7r2mzthgaNH72TjmgN1L6DIbsuyOG1wtczcPJn4+Jg==", "dev": true, "license": "MIT", "dependencies": { @@ -1970,7 +2026,7 @@ "peerDependencies": { "esbuild": "*", "rollup": "*", - "storybook": "^10.0.7", + "storybook": "^10.2.4", "vite": "*", "webpack": "*" }, @@ -1997,23 +2053,20 @@ "license": "MIT" }, "node_modules/@storybook/icons": { - "version": "1.6.0", - "resolved": "https://registry.npmjs.org/@storybook/icons/-/icons-1.6.0.tgz", - "integrity": "sha512-hcFZIjW8yQz8O8//2WTIXylm5Xsgc+lW9ISLgUk1xGmptIJQRdlhVIXCpSyLrQaaRiyhQRaVg7l3BD9S216BHw==", + "version": "2.0.1", + "resolved": "https://registry.npmjs.org/@storybook/icons/-/icons-2.0.1.tgz", + "integrity": "sha512-/smVjw88yK3CKsiuR71vNgWQ9+NuY2L+e8X7IMrFjexjm6ZR8ULrV2DRkTA61aV6ryefslzHEGDInGpnNeIocg==", "dev": true, "license": "MIT", - "engines": { - "node": ">=14.0.0" - }, "peerDependencies": { - "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta", - "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0-beta" + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", + "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" } }, "node_modules/@storybook/react-dom-shim": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/react-dom-shim/-/react-dom-shim-10.0.7.tgz", - "integrity": "sha512-bp4OnMtZGwPJQDqNRi4K5iibLbZ2TZZMkWW7oSw5jjPFpGSreSjCe8LH9yj/lDnK8Ox9bGMCBFE5RV5XuML29w==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/react-dom-shim/-/react-dom-shim-10.2.4.tgz", + "integrity": "sha512-i22OtrZ7GeZPt/odLf0vqyDhRSKyaLsHkkKSBcANQfzRRnBZmiz2FchOtWm9uvoDWybQsTruZq7kTdtpEhwyGw==", "dev": true, "license": "MIT", "funding": { @@ -2023,13 +2076,13 @@ "peerDependencies": { "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", "react-dom": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0", - "storybook": "^10.0.7" + "storybook": "^10.2.4" } }, "node_modules/@storybook/svelte": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/svelte/-/svelte-10.0.7.tgz", - "integrity": "sha512-rO+YQhHucy47Vh67z318pALmd6x+K1Kj30Fb4a6oOEw4xn4zCo9KTmkMWs24c4oduEXD/eJu3badlRmsVXzyfA==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/svelte/-/svelte-10.2.4.tgz", + "integrity": "sha512-W9R51zUCd2iHOQBg/D93+bdpYv6kbtFx+kft5X8lPKQl6yEu0aKs9i5N5GyCASOhIApgx/tkqZIJ7vgM4cqrHA==", "dev": true, "license": "MIT", "peer": true, @@ -2042,19 +2095,19 @@ "url": "https://opencollective.com/storybook" }, "peerDependencies": { - "storybook": "^10.0.7", + "storybook": "^10.2.4", "svelte": "^5.0.0" } }, "node_modules/@storybook/svelte-vite": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/svelte-vite/-/svelte-vite-10.0.7.tgz", - "integrity": "sha512-q9/RtrhX1CnznO6AO9MDEy1bsccbGeRxW28FLpgUrztV4IGZ/dFUrFIFurKRyuA3/nFsbtzp1F5jFt3RExmmTw==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/svelte-vite/-/svelte-vite-10.2.4.tgz", + "integrity": "sha512-FMgKMRdoZFDwPD6eIDMldcgp6d6NtIGuXyUJjb29qLias/gE5TI6hg+cWmmWXQRTrXwdyepeMBmIfRcZbB6REQ==", "dev": true, "license": "MIT", "dependencies": { - "@storybook/builder-vite": "10.0.7", - "@storybook/svelte": "10.0.7", + "@storybook/builder-vite": "10.2.4", + "@storybook/svelte": "10.2.4", "magic-string": "^0.30.0", "svelte2tsx": "^0.7.44", "typescript": "^4.9.4 || ^5.0.0" @@ -2065,28 +2118,28 @@ }, "peerDependencies": { "@sveltejs/vite-plugin-svelte": "^2.0.0 || ^3.0.0 || ^4.0.0 || ^5.0.0 || ^6.0.0", - "storybook": "^10.0.7", + "storybook": "^10.2.4", "svelte": "^5.0.0", "vite": "^5.0.0 || ^6.0.0 || ^7.0.0" } }, "node_modules/@storybook/sveltekit": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/@storybook/sveltekit/-/sveltekit-10.0.7.tgz", - "integrity": "sha512-ujTW7PfWvgBrzd7jzaZe9JgjUeM5YvBKm+xru6t7Dr4bdfmkKqlZHPRdXn/sy+fQNyfg6JL2WKy2KIIeA+RvSg==", + "version": "10.2.4", + "resolved": "https://registry.npmjs.org/@storybook/sveltekit/-/sveltekit-10.2.4.tgz", + "integrity": "sha512-1qDX35iSJHWo1AOd7HMzJtCHBfgahXqTWNiyZa/JMEKJ3qC1otaU8XMmTjsZ6fCRF99piNdgqtWM8+s1TJOldg==", "dev": true, "license": "MIT", "dependencies": { - "@storybook/builder-vite": "10.0.7", - "@storybook/svelte": "10.0.7", - "@storybook/svelte-vite": "10.0.7" + "@storybook/builder-vite": "10.2.4", + "@storybook/svelte": "10.2.4", + "@storybook/svelte-vite": "10.2.4" }, "funding": { "type": "opencollective", "url": "https://opencollective.com/storybook" }, "peerDependencies": { - "storybook": "^10.0.7", + "storybook": "^10.2.4", "svelte": "^5.0.0", "vite": "^5.0.0 || ^6.0.0 || ^7.0.0" } @@ -2111,9 +2164,9 @@ } }, "node_modules/@sveltejs/kit": { - "version": "2.49.2", - "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.49.2.tgz", - "integrity": "sha512-Vp3zX/qlwerQmHMP6x0Ry1oY7eKKRcOWGc2P59srOp4zcqyn+etJyQpELgOi4+ZSUgteX8Y387NuwruLgGXLUQ==", + "version": "2.52.0", + "resolved": "https://registry.npmjs.org/@sveltejs/kit/-/kit-2.52.0.tgz", + "integrity": "sha512-zG+HmJuSF7eC0e7xt2htlOcEMAdEtlVdb7+gAr+ef08EhtwUsjLxcAwBgUCJY3/5p08OVOxVZti91WfXeuLvsg==", "dev": true, "license": "MIT", "peer": true, @@ -2123,13 +2176,13 @@ "@types/cookie": "^0.6.0", "acorn": "^8.14.1", "cookie": "^0.6.0", - "devalue": "^5.3.2", + "devalue": "^5.6.2", "esm-env": "^1.2.2", "kleur": "^4.1.5", "magic-string": "^0.30.5", "mrmime": "^2.0.0", "sade": "^1.8.1", - "set-cookie-parser": "^2.6.0", + "set-cookie-parser": "^3.0.0", "sirv": "^3.0.0" }, "bin": { @@ -2142,11 +2195,15 @@ "@opentelemetry/api": "^1.0.0", "@sveltejs/vite-plugin-svelte": "^3.0.0 || ^4.0.0-next.1 || ^5.0.0 || ^6.0.0-next.0", "svelte": "^4.0.0 || ^5.0.0-next.0", + "typescript": "^5.3.3", "vite": "^5.0.3 || ^6.0.0 || ^7.0.0-beta.0" }, "peerDependenciesMeta": { "@opentelemetry/api": { "optional": true + }, + "typescript": { + "optional": true } } }, @@ -2735,14 +2792,14 @@ "license": "MIT" }, "node_modules/@types/node": { - "version": "22.16.5", - "resolved": "https://registry.npmjs.org/@types/node/-/node-22.16.5.tgz", - "integrity": "sha512-bJFoMATwIGaxxx8VJPeM8TonI8t579oRvgAuT8zFugJsJZgzqv0Fu8Mhp68iecjzG7cnN3mO2dJQ5uUM2EFrgQ==", + "version": "24.10.10", + "resolved": "https://registry.npmjs.org/@types/node/-/node-24.10.10.tgz", + "integrity": "sha512-+0/4J266CBGPUq/ELg7QUHhN25WYjE0wYTPSQJn1xeu8DOlIOPxXxrNGiLmfAWl7HMMgWFWXpt9IDjMWrF5Iow==", "dev": true, "license": "MIT", "peer": true, "dependencies": { - "undici-types": "~6.21.0" + "undici-types": "~7.16.0" } }, "node_modules/@types/react": { @@ -2763,21 +2820,20 @@ "license": "MIT" }, "node_modules/@typescript-eslint/eslint-plugin": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.37.0.tgz", - "integrity": "sha512-jsuVWeIkb6ggzB+wPCsR4e6loj+rM72ohW6IBn2C+5NCvfUVY8s33iFPySSVXqtm5Hu29Ne/9bnA0JmyLmgenA==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.56.0.tgz", + "integrity": "sha512-lRyPDLzNCuae71A3t9NEINBiTn7swyOhvUj3MyUOxb8x6g6vPEFoOU+ZRmGMusNC3X3YMhqMIX7i8ShqhT74Pw==", "dev": true, "license": "MIT", "dependencies": { - "@eslint-community/regexpp": "^4.10.0", - "@typescript-eslint/scope-manager": "8.37.0", - "@typescript-eslint/type-utils": "8.37.0", - "@typescript-eslint/utils": "8.37.0", - "@typescript-eslint/visitor-keys": "8.37.0", - "graphemer": "^1.4.0", - "ignore": "^7.0.0", + "@eslint-community/regexpp": "^4.12.2", + "@typescript-eslint/scope-manager": "8.56.0", + "@typescript-eslint/type-utils": "8.56.0", + "@typescript-eslint/utils": "8.56.0", + "@typescript-eslint/visitor-keys": "8.56.0", + "ignore": "^7.0.5", "natural-compare": "^1.4.0", - "ts-api-utils": "^2.1.0" + "ts-api-utils": "^2.4.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2787,9 +2843,9 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "@typescript-eslint/parser": "^8.37.0", - "eslint": "^8.57.0 || ^9.0.0", - "typescript": ">=4.8.4 <5.9.0" + "@typescript-eslint/parser": "^8.56.0", + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/eslint-plugin/node_modules/ignore": { @@ -2803,18 +2859,18 @@ } }, "node_modules/@typescript-eslint/parser": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.37.0.tgz", - "integrity": "sha512-kVIaQE9vrN9RLCQMQ3iyRlVJpTiDUY6woHGb30JDkfJErqrQEmtdWH3gV0PBAfGZgQXoqzXOO0T3K6ioApbbAA==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/parser/-/parser-8.56.0.tgz", + "integrity": "sha512-IgSWvLobTDOjnaxAfDTIHaECbkNlAlKv2j5SjpB2v7QHKv1FIfjwMy8FsDbVfDX/KjmCmYICcw7uGaXLhtsLNg==", "dev": true, "license": "MIT", "peer": true, "dependencies": { - "@typescript-eslint/scope-manager": "8.37.0", - "@typescript-eslint/types": "8.37.0", - "@typescript-eslint/typescript-estree": "8.37.0", - "@typescript-eslint/visitor-keys": "8.37.0", - "debug": "^4.3.4" + "@typescript-eslint/scope-manager": "8.56.0", + "@typescript-eslint/types": "8.56.0", + "@typescript-eslint/typescript-estree": "8.56.0", + "@typescript-eslint/visitor-keys": "8.56.0", + "debug": "^4.4.3" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2824,20 +2880,20 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^8.57.0 || ^9.0.0", - "typescript": ">=4.8.4 <5.9.0" + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/project-service": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.37.0.tgz", - "integrity": "sha512-BIUXYsbkl5A1aJDdYJCBAo8rCEbAvdquQ8AnLb6z5Lp1u3x5PNgSSx9A/zqYc++Xnr/0DVpls8iQ2cJs/izTXA==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/project-service/-/project-service-8.56.0.tgz", + "integrity": "sha512-M3rnyL1vIQOMeWxTWIW096/TtVP+8W3p/XnaFflhmcFp+U4zlxUxWj4XwNs6HbDeTtN4yun0GNTTDBw/SvufKg==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/tsconfig-utils": "^8.37.0", - "@typescript-eslint/types": "^8.37.0", - "debug": "^4.3.4" + "@typescript-eslint/tsconfig-utils": "^8.56.0", + "@typescript-eslint/types": "^8.56.0", + "debug": "^4.4.3" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2847,18 +2903,18 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "typescript": ">=4.8.4 <5.9.0" + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/scope-manager": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.37.0.tgz", - "integrity": "sha512-0vGq0yiU1gbjKob2q691ybTg9JX6ShiVXAAfm2jGf3q0hdP6/BruaFjL/ManAR/lj05AvYCH+5bbVo0VtzmjOA==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/scope-manager/-/scope-manager-8.56.0.tgz", + "integrity": "sha512-7UiO/XwMHquH+ZzfVCfUNkIXlp/yQjjnlYUyYz7pfvlK3/EyyN6BK+emDmGNyQLBtLGaYrTAI6KOw8tFucWL2w==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/types": "8.37.0", - "@typescript-eslint/visitor-keys": "8.37.0" + "@typescript-eslint/types": "8.56.0", + "@typescript-eslint/visitor-keys": "8.56.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2869,9 +2925,9 @@ } }, "node_modules/@typescript-eslint/tsconfig-utils": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.37.0.tgz", - "integrity": "sha512-1/YHvAVTimMM9mmlPvTec9NP4bobA1RkDbMydxG8omqwJJLEW/Iy2C4adsAESIXU3WGLXFHSZUU+C9EoFWl4Zg==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/tsconfig-utils/-/tsconfig-utils-8.56.0.tgz", + "integrity": "sha512-bSJoIIt4o3lKXD3xmDh9chZcjCz5Lk8xS7Rxn+6l5/pKrDpkCwtQNQQwZ2qRPk7TkUYhrq3WPIHXOXlbXP0itg==", "dev": true, "license": "MIT", "engines": { @@ -2882,21 +2938,21 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "typescript": ">=4.8.4 <5.9.0" + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/type-utils": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.37.0.tgz", - "integrity": "sha512-SPkXWIkVZxhgwSwVq9rqj/4VFo7MnWwVaRNznfQDc/xPYHjXnPfLWn+4L6FF1cAz6e7dsqBeMawgl7QjUMj4Ow==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/type-utils/-/type-utils-8.56.0.tgz", + "integrity": "sha512-qX2L3HWOU2nuDs6GzglBeuFXviDODreS58tLY/BALPC7iu3Fa+J7EOTwnX9PdNBxUI7Uh0ntP0YWGnxCkXzmfA==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/types": "8.37.0", - "@typescript-eslint/typescript-estree": "8.37.0", - "@typescript-eslint/utils": "8.37.0", - "debug": "^4.3.4", - "ts-api-utils": "^2.1.0" + "@typescript-eslint/types": "8.56.0", + "@typescript-eslint/typescript-estree": "8.56.0", + "@typescript-eslint/utils": "8.56.0", + "debug": "^4.4.3", + "ts-api-utils": "^2.4.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2906,14 +2962,14 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^8.57.0 || ^9.0.0", - "typescript": ">=4.8.4 <5.9.0" + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/types": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.37.0.tgz", - "integrity": "sha512-ax0nv7PUF9NOVPs+lmQ7yIE7IQmAf8LGcXbMvHX5Gm+YJUYNAl340XkGnrimxZ0elXyoQJuN5sbg6C4evKA4SQ==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/types/-/types-8.56.0.tgz", + "integrity": "sha512-DBsLPs3GsWhX5HylbP9HNG15U0bnwut55Lx12bHB9MpXxQ+R5GC8MwQe+N1UFXxAeQDvEsEDY6ZYwX03K7Z6HQ==", "dev": true, "license": "MIT", "engines": { @@ -2925,22 +2981,21 @@ } }, "node_modules/@typescript-eslint/typescript-estree": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.37.0.tgz", - "integrity": "sha512-zuWDMDuzMRbQOM+bHyU4/slw27bAUEcKSKKs3hcv2aNnc/tvE/h7w60dwVw8vnal2Pub6RT1T7BI8tFZ1fE+yg==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/typescript-estree/-/typescript-estree-8.56.0.tgz", + "integrity": "sha512-ex1nTUMWrseMltXUHmR2GAQ4d+WjkZCT4f+4bVsps8QEdh0vlBsaCokKTPlnqBFqqGaxilDNJG7b8dolW2m43Q==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/project-service": "8.37.0", - "@typescript-eslint/tsconfig-utils": "8.37.0", - "@typescript-eslint/types": "8.37.0", - "@typescript-eslint/visitor-keys": "8.37.0", - "debug": "^4.3.4", - "fast-glob": "^3.3.2", - "is-glob": "^4.0.3", - "minimatch": "^9.0.4", - "semver": "^7.6.0", - "ts-api-utils": "^2.1.0" + "@typescript-eslint/project-service": "8.56.0", + "@typescript-eslint/tsconfig-utils": "8.56.0", + "@typescript-eslint/types": "8.56.0", + "@typescript-eslint/visitor-keys": "8.56.0", + "debug": "^4.4.3", + "minimatch": "^9.0.5", + "semver": "^7.7.3", + "tinyglobby": "^0.2.15", + "ts-api-utils": "^2.4.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2950,7 +3005,7 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "typescript": ">=4.8.4 <5.9.0" + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/typescript-estree/node_modules/brace-expansion": { @@ -2980,16 +3035,16 @@ } }, "node_modules/@typescript-eslint/utils": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.37.0.tgz", - "integrity": "sha512-TSFvkIW6gGjN2p6zbXo20FzCABbyUAuq6tBvNRGsKdsSQ6a7rnV6ADfZ7f4iI3lIiXc4F4WWvtUfDw9CJ9pO5A==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/utils/-/utils-8.56.0.tgz", + "integrity": "sha512-RZ3Qsmi2nFGsS+n+kjLAYDPVlrzf7UhTffrDIKr+h2yzAlYP/y5ZulU0yeDEPItos2Ph46JAL5P/On3pe7kDIQ==", "dev": true, "license": "MIT", "dependencies": { - "@eslint-community/eslint-utils": "^4.7.0", - "@typescript-eslint/scope-manager": "8.37.0", - "@typescript-eslint/types": "8.37.0", - "@typescript-eslint/typescript-estree": "8.37.0" + "@eslint-community/eslint-utils": "^4.9.1", + "@typescript-eslint/scope-manager": "8.56.0", + "@typescript-eslint/types": "8.56.0", + "@typescript-eslint/typescript-estree": "8.56.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -2999,19 +3054,19 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^8.57.0 || ^9.0.0", - "typescript": ">=4.8.4 <5.9.0" + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/@typescript-eslint/visitor-keys": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.37.0.tgz", - "integrity": "sha512-YzfhzcTnZVPiLfP/oeKtDp2evwvHLMe0LOy7oe+hb9KKIumLNohYS9Hgp1ifwpu42YWxhZE8yieggz6JpqO/1w==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/@typescript-eslint/visitor-keys/-/visitor-keys-8.56.0.tgz", + "integrity": "sha512-q+SL+b+05Ud6LbEE35qe4A99P+htKTKVbyiNEe45eCbJFyh/HVK9QXwlrbz+Q4L8SOW4roxSVwXYj4DMBT7Ieg==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/types": "8.37.0", - "eslint-visitor-keys": "^4.2.1" + "@typescript-eslint/types": "8.56.0", + "eslint-visitor-keys": "^5.0.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -3021,6 +3076,19 @@ "url": "https://opencollective.com/typescript-eslint" } }, + "node_modules/@typescript-eslint/visitor-keys/node_modules/eslint-visitor-keys": { + "version": "5.0.0", + "resolved": "https://registry.npmjs.org/eslint-visitor-keys/-/eslint-visitor-keys-5.0.0.tgz", + "integrity": "sha512-A0XeIi7CXU7nPlfHS9loMYEKxUaONu/hTEzHTGba9Huu94Cq1hPivf+DE5erJozZOky0LfvXAyrV/tcswpLI0Q==", + "dev": true, + "license": "Apache-2.0", + "engines": { + "node": "^20.19.0 || ^22.13.0 || >=24" + }, + "funding": { + "url": "https://opencollective.com/eslint" + } + }, "node_modules/@ungap/structured-clone": { "version": "1.3.0", "resolved": "https://registry.npmjs.org/@ungap/structured-clone/-/structured-clone-1.3.0.tgz", @@ -3064,6 +3132,40 @@ } } }, + "node_modules/@vitest/coverage-v8": { + "version": "3.2.4", + "resolved": "https://registry.npmjs.org/@vitest/coverage-v8/-/coverage-v8-3.2.4.tgz", + "integrity": "sha512-EyF9SXU6kS5Ku/U82E259WSnvg6c8KTjppUncuNdm5QHpe17mwREHnjDzozC8x9MZ0xfBUFSaLkRv4TMA75ALQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@ampproject/remapping": "^2.3.0", + "@bcoe/v8-coverage": "^1.0.2", + "ast-v8-to-istanbul": "^0.3.3", + "debug": "^4.4.1", + "istanbul-lib-coverage": "^3.2.2", + "istanbul-lib-report": "^3.0.1", + "istanbul-lib-source-maps": "^5.0.6", + "istanbul-reports": "^3.1.7", + "magic-string": "^0.30.17", + "magicast": "^0.3.5", + "std-env": "^3.9.0", + "test-exclude": "^7.0.1", + "tinyrainbow": "^2.0.0" + }, + "funding": { + "url": "https://opencollective.com/vitest" + }, + "peerDependencies": { + "@vitest/browser": "3.2.4", + "vitest": "3.2.4" + }, + "peerDependenciesMeta": { + "@vitest/browser": { + "optional": true + } + } + }, "node_modules/@vitest/expect": { "version": "3.2.4", "resolved": "https://registry.npmjs.org/@vitest/expect/-/expect-3.2.4.tgz", @@ -3108,16 +3210,6 @@ } } }, - "node_modules/@vitest/mocker/node_modules/estree-walker": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", - "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", - "dev": true, - "license": "MIT", - "dependencies": { - "@types/estree": "^1.0.0" - } - }, "node_modules/@vitest/pretty-format": { "version": "3.2.4", "resolved": "https://registry.npmjs.org/@vitest/pretty-format/-/pretty-format-3.2.4.tgz", @@ -3296,6 +3388,25 @@ "node": ">=4" } }, + "node_modules/ast-v8-to-istanbul": { + "version": "0.3.11", + "resolved": "https://registry.npmjs.org/ast-v8-to-istanbul/-/ast-v8-to-istanbul-0.3.11.tgz", + "integrity": "sha512-Qya9fkoofMjCBNVdWINMjB5KZvkYfaO9/anwkWnjxibpWUxo5iHl2sOdP7/uAqaRuUYuoo8rDwnbaaKVFxoUvw==", + "dev": true, + "license": "MIT", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.31", + "estree-walker": "^3.0.3", + "js-tokens": "^10.0.0" + } + }, + "node_modules/ast-v8-to-istanbul/node_modules/js-tokens": { + "version": "10.0.0", + "resolved": "https://registry.npmjs.org/js-tokens/-/js-tokens-10.0.0.tgz", + "integrity": "sha512-lM/UBzQmfJRo9ABXbPWemivdCW8V2G8FHaHdypQaIy523snUjog0W71ayWXTjiR+ixeMyVHN2XcpnTd/liPg/Q==", + "dev": true, + "license": "MIT" + }, "node_modules/async": { "version": "3.2.6", "resolved": "https://registry.npmjs.org/async/-/async-3.2.6.tgz", @@ -3353,9 +3464,9 @@ } }, "node_modules/bits-ui": { - "version": "2.14.4", - "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.14.4.tgz", - "integrity": "sha512-W6kenhnbd/YVvur+DKkaVJ6GldE53eLewur5AhUCqslYQ0vjZr8eWlOfwZnMiPB+PF5HMVqf61vXBvmyrAmPWg==", + "version": "2.15.7", + "resolved": "https://registry.npmjs.org/bits-ui/-/bits-ui-2.15.7.tgz", + "integrity": "sha512-M9VrQAJXnT3xfhN/joEtVXhO794yBPmadZfNtDT4t4QwI8wgCBmDuv8FlH6K4v0q0Ugw07tumAPfym9MU2BGpg==", "dev": true, "license": "MIT", "dependencies": { @@ -3440,6 +3551,7 @@ "integrity": "sha512-yQbXgO/OSZVD2IsiLlro+7Hf6Q18EJrKSEsdoMzKePKXct3gvD8oLcOQdIzGupr5Fj+EDe8gO/lxc1BzfMpxvA==", "dev": true, "license": "MIT", + "optional": true, "dependencies": { "fill-range": "^7.1.1" }, @@ -3447,6 +3559,22 @@ "node": ">=8" } }, + "node_modules/bundle-name": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/bundle-name/-/bundle-name-4.1.0.tgz", + "integrity": "sha512-tjwM5exMg6BGRI+kNmTntNsvdZS1X8BFYS6tnJ2hdH0kVxM6/eVZ2xy+FqStSWvYmtfFMDLIxurorHwDKfDz5Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "run-applescript": "^7.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/cac": { "version": "6.7.14", "resolved": "https://registry.npmjs.org/cac/-/cac-6.7.14.tgz", @@ -3609,9 +3737,9 @@ } }, "node_modules/chromatic": { - "version": "12.2.0", - "resolved": "https://registry.npmjs.org/chromatic/-/chromatic-12.2.0.tgz", - "integrity": "sha512-GswmBW9ZptAoTns1BMyjbm55Z7EsIJnUvYKdQqXIBZIKbGErmpA+p4c0BYA+nzw5B0M+rb3Iqp1IaH8TFwIQew==", + "version": "13.3.5", + "resolved": "https://registry.npmjs.org/chromatic/-/chromatic-13.3.5.tgz", + "integrity": "sha512-MzPhxpl838qJUo0A55osCF2ifwPbjcIPeElr1d4SHcjnHoIcg7l1syJDrAYK/a+PcCBrOGi06jPNpQAln5hWgw==", "dev": true, "license": "MIT", "bin": { @@ -3751,9 +3879,9 @@ "license": "MIT" }, "node_modules/debug": { - "version": "4.4.1", - "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.1.tgz", - "integrity": "sha512-KcKCqiftBJcZr++7ykoDIEwSa3XWowTfNPo92BYxjXiyYEVrUQh2aLyhxBCwww+heortUFxEJYcRzosstTEBYQ==", + "version": "4.4.3", + "resolved": "https://registry.npmjs.org/debug/-/debug-4.4.3.tgz", + "integrity": "sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==", "license": "MIT", "dependencies": { "ms": "^2.1.3" @@ -3829,6 +3957,49 @@ "node": ">=0.10.0" } }, + "node_modules/default-browser": { + "version": "5.5.0", + "resolved": "https://registry.npmjs.org/default-browser/-/default-browser-5.5.0.tgz", + "integrity": "sha512-H9LMLr5zwIbSxrmvikGuI/5KGhZ8E2zH3stkMgM5LpOWDutGM2JZaj460Udnf1a+946zc7YBgrqEWwbk7zHvGw==", + "dev": true, + "license": "MIT", + "dependencies": { + "bundle-name": "^4.1.0", + "default-browser-id": "^5.0.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/default-browser-id": { + "version": "5.0.1", + "resolved": "https://registry.npmjs.org/default-browser-id/-/default-browser-id-5.0.1.tgz", + "integrity": "sha512-x1VCxdX4t+8wVfd1so/9w+vQ4vx7lKd2Qp5tDRutErwmR85OgmfX7RlLRMWafRMY7hbEiXIbudNrjOAPa/hL8Q==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/define-lazy-prop": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/define-lazy-prop/-/define-lazy-prop-3.0.0.tgz", + "integrity": "sha512-N+MeXYoqr3pOgn8xfyRPREN7gHakLYjhsHhWGT3fWAiL4IkAt0iDw14QiiEm2bE30c5XX5q0FtAA3CK5f9/BUg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/dequal": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz", @@ -3896,6 +4067,20 @@ "node": ">= 0.4" } }, + "node_modules/eastasianwidth": { + "version": "0.2.0", + "resolved": "https://registry.npmjs.org/eastasianwidth/-/eastasianwidth-0.2.0.tgz", + "integrity": "sha512-I88TYZWc9XiYHRQ4/3c5rjjfgkjhLyW2luGIheGERbNQ6OY7yTybanSpDXZa8y7VUP9YmDcYa+eyq4ca7iLqWA==", + "dev": true, + "license": "MIT" + }, + "node_modules/emoji-regex": { + "version": "9.2.2", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-9.2.2.tgz", + "integrity": "sha512-L18DaJsXSUk2+42pv8mLs5jJT2hqFkFE4j21wOmgbUqsZ2hL72NsUU785g9RXgo3s0ZNgVl42TiHp3ZtOv/Vyg==", + "dev": true, + "license": "MIT" + }, "node_modules/enhanced-resolve": { "version": "5.18.2", "resolved": "https://registry.npmjs.org/enhanced-resolve/-/enhanced-resolve-5.18.2.tgz", @@ -4031,26 +4216,25 @@ } }, "node_modules/eslint": { - "version": "9.31.0", - "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.31.0.tgz", - "integrity": "sha512-QldCVh/ztyKJJZLr4jXNUByx3gR+TDYZCRXEktiZoUR3PGy4qCmSbkxcIle8GEwGpb5JBZazlaJ/CxLidXdEbQ==", + "version": "9.39.2", + "resolved": "https://registry.npmjs.org/eslint/-/eslint-9.39.2.tgz", + "integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==", "dev": true, "license": "MIT", "peer": true, "dependencies": { - "@eslint-community/eslint-utils": "^4.2.0", + "@eslint-community/eslint-utils": "^4.8.0", "@eslint-community/regexpp": "^4.12.1", - "@eslint/config-array": "^0.21.0", - "@eslint/config-helpers": "^0.3.0", - "@eslint/core": "^0.15.0", + "@eslint/config-array": "^0.21.1", + "@eslint/config-helpers": "^0.4.2", + "@eslint/core": "^0.17.0", "@eslint/eslintrc": "^3.3.1", - "@eslint/js": "9.31.0", - "@eslint/plugin-kit": "^0.3.1", + "@eslint/js": "9.39.2", + "@eslint/plugin-kit": "^0.4.1", "@humanfs/node": "^0.16.6", "@humanwhocodes/module-importer": "^1.0.1", "@humanwhocodes/retry": "^0.4.2", "@types/estree": "^1.0.6", - "@types/json-schema": "^7.0.15", "ajv": "^6.12.4", "chalk": "^4.0.0", "cross-spawn": "^7.0.6", @@ -4109,23 +4293,23 @@ } }, "node_modules/eslint-plugin-storybook": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/eslint-plugin-storybook/-/eslint-plugin-storybook-10.0.7.tgz", - "integrity": "sha512-qOQq9KdT1jsBgT3qsxUH2n67aj1WR8D1XCoER8Q6yuVlS5TimNwk1mZeWkXVf/o4RQQT6flT2y5cG2gPLZPvJA==", + "version": "10.2.9", + "resolved": "https://registry.npmjs.org/eslint-plugin-storybook/-/eslint-plugin-storybook-10.2.9.tgz", + "integrity": "sha512-nmPxjPw2KfmosqAUb/W0jmEfAZzK97kyJ8W5KMuweCblwjIL0hI/GMsWSP8CCBPnhQ9LnuxtT8JtQUOsslcbwA==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/utils": "^8.8.1" + "@typescript-eslint/utils": "^8.48.0" }, "peerDependencies": { "eslint": ">=8", - "storybook": "^10.0.7" + "storybook": "^10.2.9" } }, "node_modules/eslint-plugin-svelte": { - "version": "3.11.0", - "resolved": "https://registry.npmjs.org/eslint-plugin-svelte/-/eslint-plugin-svelte-3.11.0.tgz", - "integrity": "sha512-KliWlkieHyEa65aQIkRwUFfHzT5Cn4u3BQQsu3KlkJOs7c1u7ryn84EWaOjEzilbKgttT4OfBURA8Uc4JBSQIw==", + "version": "3.15.0", + "resolved": "https://registry.npmjs.org/eslint-plugin-svelte/-/eslint-plugin-svelte-3.15.0.tgz", + "integrity": "sha512-QKB7zqfuB8aChOfBTComgDptMf2yxiJx7FE04nneCmtQzgTHvY8UJkuh8J2Rz7KB9FFV9aTHX6r7rdYGvG8T9Q==", "dev": true, "license": "MIT", "dependencies": { @@ -4138,7 +4322,7 @@ "postcss-load-config": "^3.1.4", "postcss-safe-parser": "^7.0.0", "semver": "^7.6.3", - "svelte-eslint-parser": "^1.3.0" + "svelte-eslint-parser": "^1.4.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -4147,7 +4331,7 @@ "url": "https://github.com/sponsors/ota-meshi" }, "peerDependencies": { - "eslint": "^8.57.1 || ^9.0.0", + "eslint": "^8.57.1 || ^9.0.0 || ^10.0.0", "svelte": "^3.37.0 || ^4.0.0 || ^5.0.0" }, "peerDependenciesMeta": { @@ -4270,6 +4454,16 @@ "node": ">=4.0" } }, + "node_modules/estree-walker": { + "version": "3.0.3", + "resolved": "https://registry.npmjs.org/estree-walker/-/estree-walker-3.0.3.tgz", + "integrity": "sha512-7RUKfXgSMMkzt6ZuXmqapOurLGPPfgj6l9uRZ7lRGolvk0y2yocc35LdcxKC5PQZdn2DMqioAQ2NoWcrTKmm6g==", + "dev": true, + "license": "MIT", + "dependencies": { + "@types/estree": "^1.0.0" + } + }, "node_modules/esutils": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/esutils/-/esutils-2.0.3.tgz", @@ -4310,36 +4504,6 @@ "dev": true, "license": "MIT" }, - "node_modules/fast-glob": { - "version": "3.3.3", - "resolved": "https://registry.npmjs.org/fast-glob/-/fast-glob-3.3.3.tgz", - "integrity": "sha512-7MptL8U0cqcFdzIzwOTHoilX9x5BrNqye7Z/LuC7kCMRio1EMSyqRK3BEAUD7sXRq4iT4AzTVuZdhgQ2TCvYLg==", - "dev": true, - "license": "MIT", - "dependencies": { - "@nodelib/fs.stat": "^2.0.2", - "@nodelib/fs.walk": "^1.2.3", - "glob-parent": "^5.1.2", - "merge2": "^1.3.0", - "micromatch": "^4.0.8" - }, - "engines": { - "node": ">=8.6.0" - } - }, - "node_modules/fast-glob/node_modules/glob-parent": { - "version": "5.1.2", - "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-5.1.2.tgz", - "integrity": "sha512-AOIgSQCepiJYwP3ARnGx+5VnTu2HBYdzbGP45eLw1vr3zB3vZLeyed1sC9hnbcOc9/SrMyM5RPQrkGz4aS9Zow==", - "dev": true, - "license": "ISC", - "dependencies": { - "is-glob": "^4.0.1" - }, - "engines": { - "node": ">= 6" - } - }, "node_modules/fast-json-stable-stringify": { "version": "2.1.0", "resolved": "https://registry.npmjs.org/fast-json-stable-stringify/-/fast-json-stable-stringify-2.1.0.tgz", @@ -4354,16 +4518,6 @@ "dev": true, "license": "MIT" }, - "node_modules/fastq": { - "version": "1.19.1", - "resolved": "https://registry.npmjs.org/fastq/-/fastq-1.19.1.tgz", - "integrity": "sha512-GwLTyxkCXjXbxqIhTsMI2Nui8huMPtnxg7krajPJAjnEG/iiOS7i+zCtWGZR9G0NBKbXKh6X9m9UIsYX/N6vvQ==", - "dev": true, - "license": "ISC", - "dependencies": { - "reusify": "^1.0.4" - } - }, "node_modules/fdir": { "version": "6.5.0", "resolved": "https://registry.npmjs.org/fdir/-/fdir-6.5.0.tgz", @@ -4418,6 +4572,7 @@ "integrity": "sha512-YsGpe3WHLK8ZYi4tWDg2Jy3ebRz2rXowDxnld4bkQB00cc/1Zw9AWnC0i9ztDJitivtQvaI9KaLyKrc+hBW0yg==", "dev": true, "license": "MIT", + "optional": true, "dependencies": { "to-regex-range": "^5.0.1" }, @@ -4484,6 +4639,23 @@ } } }, + "node_modules/foreground-child": { + "version": "3.3.1", + "resolved": "https://registry.npmjs.org/foreground-child/-/foreground-child-3.3.1.tgz", + "integrity": "sha512-gIXjKqtFuWEgzFRJA9WCQeSJLZDjgJUOMCMzxtvFq/37KojM1BFGufqsCy0r4qSQmYLsZYMeyRqzIWOMup03sw==", + "dev": true, + "license": "ISC", + "dependencies": { + "cross-spawn": "^7.0.6", + "signal-exit": "^4.0.1" + }, + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/fsevents": { "version": "2.3.2", "resolved": "https://registry.npmjs.org/fsevents/-/fsevents-2.3.2.tgz", @@ -4548,6 +4720,27 @@ "node": ">= 0.4" } }, + "node_modules/glob": { + "version": "10.5.0", + "resolved": "https://registry.npmjs.org/glob/-/glob-10.5.0.tgz", + "integrity": "sha512-DfXN8DfhJ7NH3Oe7cFmu3NCu1wKbkReJ8TorzSAFbSKrlNaQSKfIzqYqVY8zlbs2NLBbWpRiU52GX2PbaBVNkg==", + "dev": true, + "license": "ISC", + "dependencies": { + "foreground-child": "^3.1.0", + "jackspeak": "^3.1.2", + "minimatch": "^9.0.4", + "minipass": "^7.1.2", + "package-json-from-dist": "^1.0.0", + "path-scurry": "^1.11.1" + }, + "bin": { + "glob": "dist/esm/bin.mjs" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/glob-parent": { "version": "6.0.2", "resolved": "https://registry.npmjs.org/glob-parent/-/glob-parent-6.0.2.tgz", @@ -4561,6 +4754,32 @@ "node": ">=10.13.0" } }, + "node_modules/glob/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/glob/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/globals": { "version": "16.3.0", "resolved": "https://registry.npmjs.org/globals/-/globals-16.3.0.tgz", @@ -4594,13 +4813,6 @@ "dev": true, "license": "ISC" }, - "node_modules/graphemer": { - "version": "1.4.0", - "resolved": "https://registry.npmjs.org/graphemer/-/graphemer-1.4.0.tgz", - "integrity": "sha512-EtKwoO6kxCL9WO5xipiHTZlSzBm7WLT627TqC/uVRd0HKmq8NXyebnNYxDoBi7wt8eTWrUrKXCOVaFq9x1kgag==", - "dev": true, - "license": "MIT" - }, "node_modules/has-flag": { "version": "4.0.0", "resolved": "https://registry.npmjs.org/has-flag/-/has-flag-4.0.0.tgz", @@ -4909,6 +5121,13 @@ "node": ">=12" } }, + "node_modules/html-escaper": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/html-escaper/-/html-escaper-2.0.2.tgz", + "integrity": "sha512-H2iMtd0I4Mt5eYiapRdIDjp+XzelXQ0tFE4JS7YFwFevXXMmOp9myNrUvCg0D6ws8iqkRPBfKHgbwig1SmlLfg==", + "dev": true, + "license": "MIT" + }, "node_modules/html-void-elements": { "version": "3.0.0", "resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz", @@ -5035,6 +5254,22 @@ "integrity": "sha512-0aO8FkhNZlj/ZIbNi7Lxxr12obT7cL1moPfE4tg1LkX7LlLfC6DeX4l2ZEud1ukP9jNQyNnfzQVqwbwmAATY4Q==", "license": "MIT" }, + "node_modules/is-docker": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-docker/-/is-docker-3.0.0.tgz", + "integrity": "sha512-eljcgEDlEns/7AXFosB5K/2nCM4P7FQPkGc/DWLy5rmFEWvZayGrik1d9/QIY5nJ4f9YsVvBkA6kJpHn9rISdQ==", + "dev": true, + "license": "MIT", + "bin": { + "is-docker": "cli.js" + }, + "engines": { + "node": "^12.20.0 || ^14.13.1 || >=16.0.0" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-extglob": { "version": "2.1.1", "resolved": "https://registry.npmjs.org/is-extglob/-/is-extglob-2.1.1.tgz", @@ -5045,6 +5280,16 @@ "node": ">=0.10.0" } }, + "node_modules/is-fullwidth-code-point": { + "version": "3.0.0", + "resolved": "https://registry.npmjs.org/is-fullwidth-code-point/-/is-fullwidth-code-point-3.0.0.tgz", + "integrity": "sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=8" + } + }, "node_modules/is-glob": { "version": "4.0.3", "resolved": "https://registry.npmjs.org/is-glob/-/is-glob-4.0.3.tgz", @@ -5058,12 +5303,32 @@ "node": ">=0.10.0" } }, + "node_modules/is-inside-container": { + "version": "1.0.0", + "resolved": "https://registry.npmjs.org/is-inside-container/-/is-inside-container-1.0.0.tgz", + "integrity": "sha512-KIYLCCJghfHZxqjYBE7rEy0OBuTd5xCHS7tHVgvCLkx7StIoaxwNW3hCALgEUjFfeRk+MG/Qxmp/vtETEF3tRA==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-docker": "^3.0.0" + }, + "bin": { + "is-inside-container": "cli.js" + }, + "engines": { + "node": ">=14.16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/is-number": { "version": "7.0.0", "resolved": "https://registry.npmjs.org/is-number/-/is-number-7.0.0.tgz", "integrity": "sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==", "dev": true, "license": "MIT", + "optional": true, "engines": { "node": ">=0.12.0" } @@ -5080,6 +5345,22 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/is-wsl": { + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/is-wsl/-/is-wsl-3.1.0.tgz", + "integrity": "sha512-UcVfVfaK4Sc4m7X3dUSoHoozQGBEFeDC+zVo06t98xe8CzHSZZBekNXH+tu0NalHolcJ/QAGqS46Hef7QXBIMw==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-inside-container": "^1.0.0" + }, + "engines": { + "node": ">=16" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/isexe": { "version": "2.0.0", "resolved": "https://registry.npmjs.org/isexe/-/isexe-2.0.0.tgz", @@ -5087,6 +5368,76 @@ "dev": true, "license": "ISC" }, + "node_modules/istanbul-lib-coverage": { + "version": "3.2.2", + "resolved": "https://registry.npmjs.org/istanbul-lib-coverage/-/istanbul-lib-coverage-3.2.2.tgz", + "integrity": "sha512-O8dpsF+r0WV/8MNRKfnmrtCWhuKjxrq2w+jpzBL5UZKTi2LeVWnWOmWRxFlesJONmc+wLAGvKQZEOanko0LFTg==", + "dev": true, + "license": "BSD-3-Clause", + "engines": { + "node": ">=8" + } + }, + "node_modules/istanbul-lib-report": { + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/istanbul-lib-report/-/istanbul-lib-report-3.0.1.tgz", + "integrity": "sha512-GCfE1mtsHGOELCU8e/Z7YWzpmybrx/+dSTfLrvY8qRmaY6zXTKWn6WQIjaAFw069icm6GVMNkgu0NzI4iPZUNw==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "istanbul-lib-coverage": "^3.0.0", + "make-dir": "^4.0.0", + "supports-color": "^7.1.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-lib-source-maps": { + "version": "5.0.6", + "resolved": "https://registry.npmjs.org/istanbul-lib-source-maps/-/istanbul-lib-source-maps-5.0.6.tgz", + "integrity": "sha512-yg2d+Em4KizZC5niWhQaIomgf5WlL4vOOjZ5xGCmF8SnPE/mDWWXgvRExdcpCgh9lLRRa1/fSYp2ymmbJ1pI+A==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "@jridgewell/trace-mapping": "^0.3.23", + "debug": "^4.1.1", + "istanbul-lib-coverage": "^3.0.0" + }, + "engines": { + "node": ">=10" + } + }, + "node_modules/istanbul-reports": { + "version": "3.2.0", + "resolved": "https://registry.npmjs.org/istanbul-reports/-/istanbul-reports-3.2.0.tgz", + "integrity": "sha512-HGYWWS/ehqTV3xN10i23tkPkpH46MLCIMFNCaaKNavAXTF1RkqxawEPtnjnGZ6XKSInBKkiOA5BKS+aZiY3AvA==", + "dev": true, + "license": "BSD-3-Clause", + "dependencies": { + "html-escaper": "^2.0.0", + "istanbul-lib-report": "^3.0.0" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/jackspeak": { + "version": "3.4.3", + "resolved": "https://registry.npmjs.org/jackspeak/-/jackspeak-3.4.3.tgz", + "integrity": "sha512-OGlZQpz2yfahA/Rd1Y8Cd9SIEsqvXkLVoSw/cgwhnhFMDbsQFeZYoJJ7bIZBS9BcamUW96asq/npPWugM+RQBw==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "@isaacs/cliui": "^8.0.2" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + }, + "optionalDependencies": { + "@pkgjs/parseargs": "^0.11.0" + } + }, "node_modules/jiti": { "version": "2.4.2", "resolved": "https://registry.npmjs.org/jiti/-/jiti-2.4.2.tgz", @@ -5481,9 +5832,9 @@ } }, "node_modules/lodash": { - "version": "4.17.21", - "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.21.tgz", - "integrity": "sha512-v2kDEe57lecTulaDIuNTPy3Ry4gLGJ6Z1O3vE1krgXZNrsQ+LFTGHVxVjcXPs17LhbZVGedAJv8XZ1tvj5FvSg==", + "version": "4.17.23", + "resolved": "https://registry.npmjs.org/lodash/-/lodash-4.17.23.tgz", + "integrity": "sha512-LgVTMpQtIopCi79SJeDiP0TfWi5CNEc/L/aRdTh3yIvmZXTnheWpKjSZhnvMl8iXbC1tFg9gdHHDMLoV7CnG+w==", "dev": true, "license": "MIT" }, @@ -5540,6 +5891,13 @@ "url": "https://github.com/sponsors/wooorm" } }, + "node_modules/lru-cache": { + "version": "10.4.3", + "resolved": "https://registry.npmjs.org/lru-cache/-/lru-cache-10.4.3.tgz", + "integrity": "sha512-JNAzZcXrCt42VGLuYz0zfAzDfAvJWW6AfYlDBQyDV5DClI2m5sAmK+OIO7s59XfsRsWHp02jAJrRadPRGTt6SQ==", + "dev": true, + "license": "ISC" + }, "node_modules/lz-string": { "version": "1.5.0", "resolved": "https://registry.npmjs.org/lz-string/-/lz-string-1.5.0.tgz", @@ -5559,6 +5917,34 @@ "@jridgewell/sourcemap-codec": "^1.5.0" } }, + "node_modules/magicast": { + "version": "0.3.5", + "resolved": "https://registry.npmjs.org/magicast/-/magicast-0.3.5.tgz", + "integrity": "sha512-L0WhttDl+2BOsybvEOLK7fW3UA0OQ0IQ2d6Zl2x/a6vVRs3bAY0ECOSHHeL5jD+SbOpOCUEi0y1DgHEn9Qn1AQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "@babel/parser": "^7.25.4", + "@babel/types": "^7.25.4", + "source-map-js": "^1.2.0" + } + }, + "node_modules/make-dir": { + "version": "4.0.0", + "resolved": "https://registry.npmjs.org/make-dir/-/make-dir-4.0.0.tgz", + "integrity": "sha512-hXdUTZYIVOt1Ex//jAQi+wTZZpUpwBj/0QsOzqegb3rGMMeJiSEu5xLHnYfBrRV4RH2+OCSOO95Is/7x1WJ4bw==", + "dev": true, + "license": "MIT", + "dependencies": { + "semver": "^7.5.3" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/markdown-table": { "version": "3.0.4", "resolved": "https://registry.npmjs.org/markdown-table/-/markdown-table-3.0.4.tgz", @@ -5927,16 +6313,6 @@ "url": "https://opencollective.com/unified" } }, - "node_modules/merge2": { - "version": "1.4.1", - "resolved": "https://registry.npmjs.org/merge2/-/merge2-1.4.1.tgz", - "integrity": "sha512-8q7VEgMJW4J8tcfVPy8g09NcQwZdbwFEqhe/WZkoIzjn/3TGDwtOCYtXGxA3O8tPzpczCCDgv+P2P5y00ZJOOg==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">= 8" - } - }, "node_modules/micromark": { "version": "4.0.2", "resolved": "https://registry.npmjs.org/micromark/-/micromark-4.0.2.tgz", @@ -6526,6 +6902,7 @@ "integrity": "sha512-PXwfBhYu0hBCPw8Dn0E+WDYb7af3dSLVWKi3HGv84IdF4TyFoC0ysxFd0Goxw7nSv4T/PzEJQxsYsEiFCKo2BA==", "dev": true, "license": "MIT", + "optional": true, "dependencies": { "braces": "^3.0.3", "picomatch": "^2.3.1" @@ -6540,6 +6917,7 @@ "integrity": "sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==", "dev": true, "license": "MIT", + "optional": true, "engines": { "node": ">=8.6" }, @@ -6614,9 +6992,9 @@ } }, "node_modules/minizlib": { - "version": "3.0.2", - "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.0.2.tgz", - "integrity": "sha512-oG62iEk+CYt5Xj2YqI5Xi9xWUeZhDI8jjQmC5oThVH5JGCTgIjr7ciJDzC7MBzYd//WvR1OTmP5Q38Q8ShQtVA==", + "version": "3.1.0", + "resolved": "https://registry.npmjs.org/minizlib/-/minizlib-3.1.0.tgz", + "integrity": "sha512-KZxYo1BUkWD2TVFLr0MQoM8vUUigWD3LlD83a/75BqC+4qE0Hb1Vo5v1FgcfaNXvfXzr+5EhQ6ing/CaBijTlw==", "dev": true, "license": "MIT", "dependencies": { @@ -6626,22 +7004,6 @@ "node": ">= 18" } }, - "node_modules/mkdirp": { - "version": "3.0.1", - "resolved": "https://registry.npmjs.org/mkdirp/-/mkdirp-3.0.1.tgz", - "integrity": "sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==", - "dev": true, - "license": "MIT", - "bin": { - "mkdirp": "dist/cjs/src/bin.js" - }, - "engines": { - "node": ">=10" - }, - "funding": { - "url": "https://github.com/sponsors/isaacs" - } - }, "node_modules/mode-watcher": { "version": "1.1.0", "resolved": "https://registry.npmjs.org/mode-watcher/-/mode-watcher-1.1.0.tgz", @@ -6728,6 +7090,25 @@ "url": "https://github.com/sponsors/ljharb" } }, + "node_modules/open": { + "version": "10.2.0", + "resolved": "https://registry.npmjs.org/open/-/open-10.2.0.tgz", + "integrity": "sha512-YgBpdJHPyQ2UE5x+hlSXcnejzAvD0b22U2OuAP+8OnlJT+PjWPxtgmGqKKc+RgTM63U9gN0YzrYc71R2WT/hTA==", + "dev": true, + "license": "MIT", + "dependencies": { + "default-browser": "^5.2.1", + "define-lazy-prop": "^3.0.0", + "is-inside-container": "^1.0.0", + "wsl-utils": "^0.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/opener": { "version": "1.5.2", "resolved": "https://registry.npmjs.org/opener/-/opener-1.5.2.tgz", @@ -6788,6 +7169,13 @@ "url": "https://github.com/sponsors/sindresorhus" } }, + "node_modules/package-json-from-dist": { + "version": "1.0.1", + "resolved": "https://registry.npmjs.org/package-json-from-dist/-/package-json-from-dist-1.0.1.tgz", + "integrity": "sha512-UEZIS3/by4OC8vL3P2dTXRETpebLI2NiI5vIrjaD/5UtrkFX/tNbwjTSRAGC/+7CAo2pIcBaRgWmcBBHcsaCIw==", + "dev": true, + "license": "BlueOak-1.0.0" + }, "node_modules/parent-module": { "version": "1.0.1", "resolved": "https://registry.npmjs.org/parent-module/-/parent-module-1.0.1.tgz", @@ -6834,6 +7222,23 @@ "node": ">=8" } }, + "node_modules/path-scurry": { + "version": "1.11.1", + "resolved": "https://registry.npmjs.org/path-scurry/-/path-scurry-1.11.1.tgz", + "integrity": "sha512-Xa4Nw17FS9ApQFJ9umLiJS4orGjm7ZzwUrwamcGQuHSzDyth9boKDaycYdDcZDuqYATXw4HFXgaqWTctW/v1HA==", + "dev": true, + "license": "BlueOak-1.0.0", + "dependencies": { + "lru-cache": "^10.2.0", + "minipass": "^5.0.0 || ^6.0.2 || ^7.0.0" + }, + "engines": { + "node": ">=16 || 14 >=14.18" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/pathe": { "version": "2.0.3", "resolved": "https://registry.npmjs.org/pathe/-/pathe-2.0.3.tgz", @@ -7238,30 +7643,6 @@ "node": ">=6" } }, - "node_modules/prompts": { - "version": "2.4.2", - "resolved": "https://registry.npmjs.org/prompts/-/prompts-2.4.2.tgz", - "integrity": "sha512-NxNv/kLguCA7p3jE8oL2aEBsrJWgAakBpgmgK6lpPWV+WuOmY6r2/zbAVnP+T8bQlA0nzHXSJSJW0Hq7ylaD2Q==", - "dev": true, - "license": "MIT", - "dependencies": { - "kleur": "^3.0.3", - "sisteransi": "^1.0.5" - }, - "engines": { - "node": ">= 6" - } - }, - "node_modules/prompts/node_modules/kleur": { - "version": "3.0.3", - "resolved": "https://registry.npmjs.org/kleur/-/kleur-3.0.3.tgz", - "integrity": "sha512-eTIzlVOSUR+JxdDFepEYcBMtZ9Qqdef+rnzWdRZuMbOywu5tO2w2N7rqjoANZ5k9vywhL6Br1VRjUIgTQx4E8w==", - "dev": true, - "license": "MIT", - "engines": { - "node": ">=6" - } - }, "node_modules/property-information": { "version": "7.1.0", "resolved": "https://registry.npmjs.org/property-information/-/property-information-7.1.0.tgz", @@ -7283,9 +7664,9 @@ } }, "node_modules/qs": { - "version": "6.14.0", - "resolved": "https://registry.npmjs.org/qs/-/qs-6.14.0.tgz", - "integrity": "sha512-YWWTjgABSKcvs/nWBi9PycY/JiPJqOD4JA6o9Sej2AtvSGarXxKC3OQSk4pAarbdQlKAh5D4FCQkJNkW+GAn3w==", + "version": "6.15.0", + "resolved": "https://registry.npmjs.org/qs/-/qs-6.15.0.tgz", + "integrity": "sha512-mAZTtNCeetKMH+pSjrb76NAM8V9a05I9aBZOHztWy/UqcJdQYNsf59vrRKWnojAT9Y+GbIvoTBC++CPHqpDBhQ==", "dev": true, "license": "BSD-3-Clause", "dependencies": { @@ -7298,27 +7679,6 @@ "url": "https://github.com/sponsors/ljharb" } }, - "node_modules/queue-microtask": { - "version": "1.2.3", - "resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz", - "integrity": "sha512-NuaNSa6flKT5JaSYQzJok04JzTL1CA6aGhv5rfLW3PgqA+M2ChpZQnAC8h8i4ZFkBS8X5RqkDBHA7r4hej3K9A==", - "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], - "license": "MIT" - }, "node_modules/react": { "version": "19.1.0", "resolved": "https://registry.npmjs.org/react/-/react-19.1.0.tgz", @@ -7596,17 +7956,6 @@ "node": ">=4" } }, - "node_modules/reusify": { - "version": "1.1.0", - "resolved": "https://registry.npmjs.org/reusify/-/reusify-1.1.0.tgz", - "integrity": "sha512-g6QUff04oZpHs0eG5p83rFLhHeV00ug/Yf9nZM6fLeUrPguBTkTQOdpAWWspMh55TZfVQDPaN3NQJfbVRAxdIw==", - "dev": true, - "license": "MIT", - "engines": { - "iojs": ">=1.0.0", - "node": ">=0.10.0" - } - }, "node_modules/rollup": { "version": "4.45.1", "resolved": "https://registry.npmjs.org/rollup/-/rollup-4.45.1.tgz", @@ -7648,28 +7997,17 @@ "fsevents": "~2.3.2" } }, - "node_modules/run-parallel": { - "version": "1.2.0", - "resolved": "https://registry.npmjs.org/run-parallel/-/run-parallel-1.2.0.tgz", - "integrity": "sha512-5l4VyZR86LZ/lDxZTR6jqL8AFE2S0IFLMP26AbjsLVADxHdhB/c0GUsH+y39UfCi3dzz8OlQuPmnaJOMoDHQBA==", + "node_modules/run-applescript": { + "version": "7.1.0", + "resolved": "https://registry.npmjs.org/run-applescript/-/run-applescript-7.1.0.tgz", + "integrity": "sha512-DPe5pVFaAsinSaV6QjQ6gdiedWDcRCbUuiQfQa2wmWV7+xC9bGulGI8+TdRmoFkAPaBXk8CrAbnlY2ISniJ47Q==", "dev": true, - "funding": [ - { - "type": "github", - "url": "https://github.com/sponsors/feross" - }, - { - "type": "patreon", - "url": "https://www.patreon.com/feross" - }, - { - "type": "consulting", - "url": "https://feross.org/support" - } - ], "license": "MIT", - "dependencies": { - "queue-microtask": "^1.2.2" + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" } }, "node_modules/runed": { @@ -7758,9 +8096,9 @@ "license": "MIT" }, "node_modules/semver": { - "version": "7.7.2", - "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.2.tgz", - "integrity": "sha512-RF0Fw+rO5AMf9MAyaRXI4AV0Ulj5lMHqVxxdSgiVbixSCXoEmmX/jk0CuJw4+3SqroYO9VoUh+HcuJivvtJemA==", + "version": "7.7.3", + "resolved": "https://registry.npmjs.org/semver/-/semver-7.7.3.tgz", + "integrity": "sha512-SdsKMrI9TdgjdweUSR9MweHA4EJ8YxHn8DFaDisvhVlUOe4BF1tLD7GAj0lIqWVl+dPb/rExr0Btby5loQm20Q==", "dev": true, "license": "ISC", "bin": { @@ -7771,9 +8109,9 @@ } }, "node_modules/set-cookie-parser": { - "version": "2.7.1", - "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-2.7.1.tgz", - "integrity": "sha512-IOc8uWeOZgnb3ptbCURJWNjWUPcO3ZnTTdzsurqERrP6nPyv+paC55vJM0LpOlT2ne+Ix+9+CRG1MNLlyZ4GjQ==", + "version": "3.0.1", + "resolved": "https://registry.npmjs.org/set-cookie-parser/-/set-cookie-parser-3.0.1.tgz", + "integrity": "sha512-n7Z7dXZhJbwuAHhNzkTti6Aw9QDDjZtm3JTpTGATIdNzdQz5GuFs22w90BcvF4INfnrL5xrX3oGsuqO5Dx3A1Q==", "dev": true, "license": "MIT" }, @@ -7883,6 +8221,19 @@ "dev": true, "license": "ISC" }, + "node_modules/signal-exit": { + "version": "4.1.0", + "resolved": "https://registry.npmjs.org/signal-exit/-/signal-exit-4.1.0.tgz", + "integrity": "sha512-bzyZ1e88w9O1iNJbKnOlvYTrWPDl46O1bG0D3XInv+9tkPrxrN8jUUTiFlDkkmKWgn1M6CfIA13SuGqOa9Korw==", + "dev": true, + "license": "ISC", + "engines": { + "node": ">=14" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/sirv": { "version": "3.0.1", "resolved": "https://registry.npmjs.org/sirv/-/sirv-3.0.1.tgz", @@ -7898,13 +8249,6 @@ "node": ">=18" } }, - "node_modules/sisteransi": { - "version": "1.0.5", - "resolved": "https://registry.npmjs.org/sisteransi/-/sisteransi-1.0.5.tgz", - "integrity": "sha512-bLGGlR1QxBcynn2d5YmDX4MGjlZvy2MRBDRNHLJ8VI6l6+9FUiyTFNJ0IveOSP0bcXgVDPRcfGqA0pjaqUpfVg==", - "dev": true, - "license": "MIT" - }, "node_modules/source-map": { "version": "0.6.1", "resolved": "https://registry.npmjs.org/source-map/-/source-map-0.6.1.tgz", @@ -7950,23 +8294,24 @@ "license": "MIT" }, "node_modules/storybook": { - "version": "10.0.7", - "resolved": "https://registry.npmjs.org/storybook/-/storybook-10.0.7.tgz", - "integrity": "sha512-7smAu0o+kdm378Q2uIddk32pn0UdIbrtTVU+rXRVtTVTCrK/P2cCui2y4JH+Bl3NgEq1bbBQpCAF/HKrDjk2Qw==", + "version": "10.2.9", + "resolved": "https://registry.npmjs.org/storybook/-/storybook-10.2.9.tgz", + "integrity": "sha512-DGok7XwIwdPWF+a49Yw+4madER5DZWRo9CdyySBLT3zeuxiEPt0Ua7ouJHm/y6ojnb/FVKZcQe8YmrE71s0qPQ==", "dev": true, "license": "MIT", "peer": true, "dependencies": { "@storybook/global": "^5.0.0", - "@storybook/icons": "^1.6.0", + "@storybook/icons": "^2.0.1", "@testing-library/jest-dom": "^6.6.3", "@testing-library/user-event": "^14.6.1", "@vitest/expect": "3.2.4", - "@vitest/mocker": "3.2.4", "@vitest/spy": "3.2.4", - "esbuild": "^0.18.0 || ^0.19.0 || ^0.20.0 || ^0.21.0 || ^0.22.0 || ^0.23.0 || ^0.24.0 || ^0.25.0", + "esbuild": "^0.18.0 || ^0.19.0 || ^0.20.0 || ^0.21.0 || ^0.22.0 || ^0.23.0 || ^0.24.0 || ^0.25.0 || ^0.26.0 || ^0.27.0", + "open": "^10.2.0", "recast": "^0.23.5", - "semver": "^7.6.2", + "semver": "^7.7.3", + "use-sync-external-store": "^1.5.0", "ws": "^8.18.0" }, "bin": { @@ -7985,6 +8330,60 @@ } } }, + "node_modules/string-width": { + "version": "5.1.2", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-5.1.2.tgz", + "integrity": "sha512-HnLOCR3vjcY8beoNLtcjZ5/nxn2afmME6lhrDrebokqMap+XbeW8n9TXpPDOqdGK5qcI3oT0GKTW6wC7EMiVqA==", + "dev": true, + "license": "MIT", + "dependencies": { + "eastasianwidth": "^0.2.0", + "emoji-regex": "^9.2.2", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, + "node_modules/string-width-cjs": { + "name": "string-width", + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/string-width-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/string-width-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/stringify-entities": { "version": "4.0.4", "resolved": "https://registry.npmjs.org/stringify-entities/-/stringify-entities-4.0.4.tgz", @@ -8015,6 +8414,20 @@ "url": "https://github.com/chalk/strip-ansi?sponsor=1" } }, + "node_modules/strip-ansi-cjs": { + "name": "strip-ansi", + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, "node_modules/strip-ansi/node_modules/ansi-regex": { "version": "6.1.0", "resolved": "https://registry.npmjs.org/ansi-regex/-/ansi-regex-6.1.0.tgz", @@ -8186,9 +8599,9 @@ } }, "node_modules/svelte-eslint-parser": { - "version": "1.3.0", - "resolved": "https://registry.npmjs.org/svelte-eslint-parser/-/svelte-eslint-parser-1.3.0.tgz", - "integrity": "sha512-VCgMHKV7UtOGcGLGNFSbmdm6kEKjtzo5nnpGU/mnx4OsFY6bZ7QwRF5DUx+Hokw5Lvdyo8dpk8B1m8mliomrNg==", + "version": "1.4.1", + "resolved": "https://registry.npmjs.org/svelte-eslint-parser/-/svelte-eslint-parser-1.4.1.tgz", + "integrity": "sha512-1eqkfQ93goAhjAXxZiu1SaKI9+0/sxp4JIWQwUpsz7ybehRE5L8dNuz7Iry7K22R47p5/+s9EM+38nHV2OlgXA==", "dev": true, "license": "MIT", "dependencies": { @@ -8200,7 +8613,8 @@ "postcss-selector-parser": "^7.0.0" }, "engines": { - "node": "^18.18.0 || ^20.9.0 || >=21.1.0" + "node": "^18.18.0 || ^20.9.0 || >=21.1.0", + "pnpm": "10.24.0" }, "funding": { "url": "https://github.com/sponsors/ota-meshi" @@ -8215,9 +8629,9 @@ } }, "node_modules/svelte-eslint-parser/node_modules/postcss-selector-parser": { - "version": "7.1.0", - "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-7.1.0.tgz", - "integrity": "sha512-8sLjZwK0R+JlxlYcTuVnyT2v+htpdrjDOKuMcOVdYjt52Lh8hWRYpxBPoKx/Zg+bcjc3wx6fmQevMmUztS/ccA==", + "version": "7.1.1", + "resolved": "https://registry.npmjs.org/postcss-selector-parser/-/postcss-selector-parser-7.1.1.tgz", + "integrity": "sha512-orRsuYpJVw8LdAwqqLykBj9ecS5/cRHlI5+nvTo8LcCKmzDmqVORXtOIYEEQuL9D4BxtA1lm5isAqzQZCoQ6Eg==", "dev": true, "license": "MIT", "dependencies": { @@ -8319,9 +8733,9 @@ } }, "node_modules/svelte2tsx": { - "version": "0.7.45", - "resolved": "https://registry.npmjs.org/svelte2tsx/-/svelte2tsx-0.7.45.tgz", - "integrity": "sha512-cSci+mYGygYBHIZLHlm/jYlEc1acjAHqaQaDFHdEBpUueM9kSTnPpvPtSl5VkJOU1qSJ7h1K+6F/LIUYiqC8VA==", + "version": "0.7.47", + "resolved": "https://registry.npmjs.org/svelte2tsx/-/svelte2tsx-0.7.47.tgz", + "integrity": "sha512-1aw/MFKVPM96OBevJdC12do2an9t5Zwr3Va9amLgTLpJje36ibD1iIHpuqCYWUrdR9vw6g6btKGQPmsqE8ZYCw==", "dev": true, "license": "MIT", "dependencies": { @@ -8391,23 +8805,63 @@ } }, "node_modules/tar": { - "version": "7.4.3", - "resolved": "https://registry.npmjs.org/tar/-/tar-7.4.3.tgz", - "integrity": "sha512-5S7Va8hKfV7W5U6g3aYxXmlPoZVAwUMy9AOKyF2fVuZa2UD3qZjg578OrLRt8PcNN1PleVaL/5/yYATNL0ICUw==", + "version": "7.5.9", + "resolved": "https://registry.npmjs.org/tar/-/tar-7.5.9.tgz", + "integrity": "sha512-BTLcK0xsDh2+PUe9F6c2TlRp4zOOBMTkoQHQIWSIzI0R7KG46uEwq4OPk2W7bZcprBMsuaeFsqwYr7pjh6CuHg==", "dev": true, - "license": "ISC", + "license": "BlueOak-1.0.0", "dependencies": { "@isaacs/fs-minipass": "^4.0.0", "chownr": "^3.0.0", "minipass": "^7.1.2", - "minizlib": "^3.0.1", - "mkdirp": "^3.0.1", + "minizlib": "^3.1.0", "yallist": "^5.0.0" }, "engines": { "node": ">=18" } }, + "node_modules/test-exclude": { + "version": "7.0.1", + "resolved": "https://registry.npmjs.org/test-exclude/-/test-exclude-7.0.1.tgz", + "integrity": "sha512-pFYqmTw68LXVjeWJMST4+borgQP2AyMNbg1BpZh9LbyhUeNkeaPF9gzfPGUAnSMV3qPYdWUwDIjjCLiSDOl7vg==", + "dev": true, + "license": "ISC", + "dependencies": { + "@istanbuljs/schema": "^0.1.2", + "glob": "^10.4.1", + "minimatch": "^9.0.4" + }, + "engines": { + "node": ">=18" + } + }, + "node_modules/test-exclude/node_modules/brace-expansion": { + "version": "2.0.2", + "resolved": "https://registry.npmjs.org/brace-expansion/-/brace-expansion-2.0.2.tgz", + "integrity": "sha512-Jt0vHyM+jmUBqojB7E1NIYadt0vI0Qxjxd2TErW94wDz+E2LAm5vKMXXwg6ZZBTHPuUlDgQHKXvjGBdfcF1ZDQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "balanced-match": "^1.0.0" + } + }, + "node_modules/test-exclude/node_modules/minimatch": { + "version": "9.0.5", + "resolved": "https://registry.npmjs.org/minimatch/-/minimatch-9.0.5.tgz", + "integrity": "sha512-G6T0ZX48xgozx7587koeX9Ys2NYy6Gmv//P89sEte9V9whIapMNF4idKxnW2QtCcLiTWlb/wfCabAtAFWhhBow==", + "dev": true, + "license": "ISC", + "dependencies": { + "brace-expansion": "^2.0.1" + }, + "engines": { + "node": ">=16 || 14 >=14.17" + }, + "funding": { + "url": "https://github.com/sponsors/isaacs" + } + }, "node_modules/tiny-invariant": { "version": "1.3.3", "resolved": "https://registry.npmjs.org/tiny-invariant/-/tiny-invariant-1.3.3.tgz", @@ -8482,6 +8936,7 @@ "integrity": "sha512-65P7iz6X5yEr1cwcgvQxbbIw7Uk3gOy5dIdtZ4rDveLqhrdJP+Li/Hx6tyK0NEb+2GCyneCMJiGqrADCSNk8sQ==", "dev": true, "license": "MIT", + "optional": true, "dependencies": { "is-number": "^7.0.0" }, @@ -8520,9 +8975,9 @@ } }, "node_modules/ts-api-utils": { - "version": "2.1.0", - "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.1.0.tgz", - "integrity": "sha512-CUgTZL1irw8u29bzrOD/nH85jqyc74D6SshFgujOIA7osm2Rz7dYH77agkx7H4FBNxDq7Cjf+IjaX/8zwFW+ZQ==", + "version": "2.4.0", + "resolved": "https://registry.npmjs.org/ts-api-utils/-/ts-api-utils-2.4.0.tgz", + "integrity": "sha512-3TaVTaAv2gTiMB35i3FiGJaRfwb3Pyn/j3m/bfAvGe8FB7CF6u+LMYqYlDh7reQf7UNvoTvdfAqHGmPGOSsPmA==", "dev": true, "license": "MIT", "engines": { @@ -8601,16 +9056,16 @@ } }, "node_modules/typescript-eslint": { - "version": "8.37.0", - "resolved": "https://registry.npmjs.org/typescript-eslint/-/typescript-eslint-8.37.0.tgz", - "integrity": "sha512-TnbEjzkE9EmcO0Q2zM+GE8NQLItNAJpMmED1BdgoBMYNdqMhzlbqfdSwiRlAzEK2pA9UzVW0gzaaIzXWg2BjfA==", + "version": "8.56.0", + "resolved": "https://registry.npmjs.org/typescript-eslint/-/typescript-eslint-8.56.0.tgz", + "integrity": "sha512-c7toRLrotJ9oixgdW7liukZpsnq5CZ7PuKztubGYlNppuTqhIoWfhgHo/7EU0v06gS2l/x0i2NEFK1qMIf0rIg==", "dev": true, "license": "MIT", "dependencies": { - "@typescript-eslint/eslint-plugin": "8.37.0", - "@typescript-eslint/parser": "8.37.0", - "@typescript-eslint/typescript-estree": "8.37.0", - "@typescript-eslint/utils": "8.37.0" + "@typescript-eslint/eslint-plugin": "8.56.0", + "@typescript-eslint/parser": "8.56.0", + "@typescript-eslint/typescript-estree": "8.56.0", + "@typescript-eslint/utils": "8.56.0" }, "engines": { "node": "^18.18.0 || ^20.9.0 || >=21.1.0" @@ -8620,14 +9075,14 @@ "url": "https://opencollective.com/typescript-eslint" }, "peerDependencies": { - "eslint": "^8.57.0 || ^9.0.0", - "typescript": ">=4.8.4 <5.9.0" + "eslint": "^8.57.0 || ^9.0.0 || ^10.0.0", + "typescript": ">=4.8.4 <6.0.0" } }, "node_modules/undici-types": { - "version": "6.21.0", - "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz", - "integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==", + "version": "7.16.0", + "resolved": "https://registry.npmjs.org/undici-types/-/undici-types-7.16.0.tgz", + "integrity": "sha512-Zz+aZWSj8LE6zoxD+xrjh4VfkIG8Ya6LvYkZqtUQGJPZjYl53ypCaUwWqo7eI0x66KBGeRo+mlBEkMSeSZ38Nw==", "dev": true, "license": "MIT" }, @@ -8814,9 +9269,9 @@ } }, "node_modules/unplugin": { - "version": "2.3.10", - "resolved": "https://registry.npmjs.org/unplugin/-/unplugin-2.3.10.tgz", - "integrity": "sha512-6NCPkv1ClwH+/BGE9QeoTIl09nuiAt0gS28nn1PvYXsGKRwM2TCbFA2QiilmehPDTXIe684k4rZI1yl3A1PCUw==", + "version": "2.3.11", + "resolved": "https://registry.npmjs.org/unplugin/-/unplugin-2.3.11.tgz", + "integrity": "sha512-5uKD0nqiYVzlmCRs01Fhs2BdkEgBS3SAVP6ndrBsuK42iC2+JHyxM05Rm9G8+5mkmRtzMZGY8Ct5+mliZxU/Ww==", "dev": true, "license": "MIT", "dependencies": { @@ -8846,6 +9301,16 @@ "dev": true, "license": "MIT" }, + "node_modules/use-sync-external-store": { + "version": "1.6.0", + "resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz", + "integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==", + "dev": true, + "license": "MIT", + "peerDependencies": { + "react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0" + } + }, "node_modules/util-deprecate": { "version": "1.0.2", "resolved": "https://registry.npmjs.org/util-deprecate/-/util-deprecate-1.0.2.tgz", @@ -9278,6 +9743,91 @@ "node": ">=0.10.0" } }, + "node_modules/wrap-ansi": { + "version": "8.1.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-8.1.0.tgz", + "integrity": "sha512-si7QWI6zUMq56bESFvagtmzMdGOtoxfR+Sez11Mobfc7tm+VkUckk9bW2UeffTGVUbOksxmSw0AA2gs8g71NCQ==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^6.1.0", + "string-width": "^5.0.1", + "strip-ansi": "^7.0.1" + }, + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs": { + "name": "wrap-ansi", + "version": "7.0.0", + "resolved": "https://registry.npmjs.org/wrap-ansi/-/wrap-ansi-7.0.0.tgz", + "integrity": "sha512-YVGIj2kamLSTxw6NsZjoBxfSwsn0ycdesmc4p+Q21c5zPuZ1pl+NfxVdxPtdHvmNVOQ6XSYG4AUtyt/Fi7D16Q==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-styles": "^4.0.0", + "string-width": "^4.1.0", + "strip-ansi": "^6.0.0" + }, + "engines": { + "node": ">=10" + }, + "funding": { + "url": "https://github.com/chalk/wrap-ansi?sponsor=1" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/emoji-regex": { + "version": "8.0.0", + "resolved": "https://registry.npmjs.org/emoji-regex/-/emoji-regex-8.0.0.tgz", + "integrity": "sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==", + "dev": true, + "license": "MIT" + }, + "node_modules/wrap-ansi-cjs/node_modules/string-width": { + "version": "4.2.3", + "resolved": "https://registry.npmjs.org/string-width/-/string-width-4.2.3.tgz", + "integrity": "sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==", + "dev": true, + "license": "MIT", + "dependencies": { + "emoji-regex": "^8.0.0", + "is-fullwidth-code-point": "^3.0.0", + "strip-ansi": "^6.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi-cjs/node_modules/strip-ansi": { + "version": "6.0.1", + "resolved": "https://registry.npmjs.org/strip-ansi/-/strip-ansi-6.0.1.tgz", + "integrity": "sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==", + "dev": true, + "license": "MIT", + "dependencies": { + "ansi-regex": "^5.0.1" + }, + "engines": { + "node": ">=8" + } + }, + "node_modules/wrap-ansi/node_modules/ansi-styles": { + "version": "6.2.3", + "resolved": "https://registry.npmjs.org/ansi-styles/-/ansi-styles-6.2.3.tgz", + "integrity": "sha512-4Dj6M28JB+oAH8kFkTLUo+a2jwOFkuqb3yucU0CANcRRUbxS0cP0nZYCGjcc3BNXwRIsUVmDGgzawme7zvJHvg==", + "dev": true, + "license": "MIT", + "engines": { + "node": ">=12" + }, + "funding": { + "url": "https://github.com/chalk/ansi-styles?sponsor=1" + } + }, "node_modules/ws": { "version": "8.18.3", "resolved": "https://registry.npmjs.org/ws/-/ws-8.18.3.tgz", @@ -9300,6 +9850,22 @@ } } }, + "node_modules/wsl-utils": { + "version": "0.1.0", + "resolved": "https://registry.npmjs.org/wsl-utils/-/wsl-utils-0.1.0.tgz", + "integrity": "sha512-h3Fbisa2nKGPxCpm89Hk33lBLsnaGBvctQopaBSOW/uIs6FTe1ATyAnKFJrzVs9vpGdsTe73WF3V4lIsk4Gacw==", + "dev": true, + "license": "MIT", + "dependencies": { + "is-wsl": "^3.1.0" + }, + "engines": { + "node": ">=18" + }, + "funding": { + "url": "https://github.com/sponsors/sindresorhus" + } + }, "node_modules/yallist": { "version": "5.0.0", "resolved": "https://registry.npmjs.org/yallist/-/yallist-5.0.0.tgz", diff --git a/tools/server/webui/package.json b/tools/server/webui/package.json index a361ce76e3..0b74e301b1 100644 --- a/tools/server/webui/package.json +++ b/tools/server/webui/package.json @@ -23,31 +23,32 @@ "cleanup": "rm -rf .svelte-kit build node_modules test-results" }, "devDependencies": { - "@chromatic-com/storybook": "^4.1.2", + "@chromatic-com/storybook": "^5.0.0", "@eslint/compat": "^1.2.5", "@eslint/js": "^9.18.0", "@internationalized/date": "^3.10.1", "@lucide/svelte": "^0.515.0", "@playwright/test": "^1.49.1", - "@storybook/addon-a11y": "^10.0.7", - "@storybook/addon-docs": "^10.0.7", + "@storybook/addon-a11y": "^10.2.4", + "@storybook/addon-docs": "^10.2.4", "@storybook/addon-svelte-csf": "^5.0.10", - "@storybook/addon-vitest": "^10.0.7", - "@storybook/sveltekit": "^10.0.7", + "@storybook/addon-vitest": "^10.2.4", + "@storybook/sveltekit": "^10.2.4", "@sveltejs/adapter-static": "^3.0.10", "@sveltejs/kit": "^2.48.4", "@sveltejs/vite-plugin-svelte": "^6.2.1", "@tailwindcss/forms": "^0.5.9", "@tailwindcss/typography": "^0.5.15", "@tailwindcss/vite": "^4.0.0", - "@types/node": "^22", + "@types/node": "^24", "@vitest/browser": "^3.2.3", + "@vitest/coverage-v8": "^3.2.3", "bits-ui": "^2.14.4", "clsx": "^2.1.1", "dexie": "^4.0.11", "eslint": "^9.18.0", "eslint-config-prettier": "^10.0.1", - "eslint-plugin-storybook": "^10.0.7", + "eslint-plugin-storybook": "^10.2.4", "eslint-plugin-svelte": "^3.0.0", "fflate": "^0.8.2", "globals": "^16.0.0", @@ -61,7 +62,7 @@ "rehype-katex": "^7.0.1", "remark-math": "^6.0.0", "sass": "^1.93.3", - "storybook": "^10.0.7", + "storybook": "^10.2.4", "svelte": "^5.38.2", "svelte-check": "^4.0.0", "tailwind-merge": "^3.3.1", diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 9705040a4d..3ab21f0cc7 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -14,11 +14,11 @@ --popover-foreground: oklch(0.145 0 0); --primary: oklch(0.205 0 0); --primary-foreground: oklch(0.985 0 0); - --secondary: oklch(0.97 0 0); + --secondary: oklch(0.95 0 0); --secondary-foreground: oklch(0.205 0 0); --muted: oklch(0.97 0 0); --muted-foreground: oklch(0.556 0 0); - --accent: oklch(0.97 0 0); + --accent: oklch(0.95 0 0); --accent-foreground: oklch(0.205 0 0); --destructive: oklch(0.577 0.245 27.325); --border: oklch(0.875 0 0); @@ -37,7 +37,7 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.975 0 0); + --code-background: oklch(0.985 0 0); --code-foreground: oklch(0.145 0 0); --layer-popover: 1000000; } @@ -51,7 +51,7 @@ --popover-foreground: oklch(0.985 0 0); --primary: oklch(0.922 0 0); --primary-foreground: oklch(0.205 0 0); - --secondary: oklch(0.269 0 0); + --secondary: oklch(0.29 0 0); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); @@ -116,12 +116,62 @@ --color-sidebar-ring: var(--sidebar-ring); } +:root { + --chat-form-area-height: 8rem; + --chat-form-area-offset: 2rem; + --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); +} + +@media (min-width: 640px) { + :root { + --chat-form-area-height: 24rem; + --chat-form-area-offset: 12rem; + } +} + @layer base { * { @apply border-border outline-ring/50; } + body { @apply bg-background text-foreground; + scrollbar-width: thin; + scrollbar-gutter: stable; + } + + /* Global scrollbar styling - visible only on hover */ + * { + scrollbar-width: thin; + scrollbar-color: transparent transparent; + transition: scrollbar-color 0.2s ease; + } + + *:hover { + scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent; + } + + *::-webkit-scrollbar { + width: 6px; + height: 6px; + } + + *::-webkit-scrollbar-track { + background: transparent; + } + + *::-webkit-scrollbar-thumb { + background: transparent; + border-radius: 3px; + transition: background 0.2s ease; + } + + *:hover::-webkit-scrollbar-thumb { + background: hsl(var(--muted-foreground) / 0.3); + } + + *::-webkit-scrollbar-thumb:hover { + background: hsl(var(--muted-foreground) / 0.5); } } diff --git a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte similarity index 99% rename from tools/server/webui/src/lib/components/app/misc/ActionButton.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte index 411a8b6094..4494ea880b 100644 --- a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte @@ -37,6 +37,7 @@ aria-label={ariaLabel || tooltip} > {@const IconComponent = icon} + diff --git a/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte diff --git a/tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte similarity index 94% rename from tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte index 173685510f..1ae3d21774 100644 --- a/tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte @@ -16,7 +16,7 @@ variant="ghost" size="sm" class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}" - onclick={(e) => { + onclick={(e: MouseEvent) => { e.stopPropagation(); onRemove?.(id); }} diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte new file mode 100644 index 0000000000..b20e79b5e0 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte @@ -0,0 +1,46 @@ + + +
+
+ +
+ + {#if showPreview} + + {/if} +
diff --git a/tools/server/webui/src/lib/components/app/actions/index.ts b/tools/server/webui/src/lib/components/app/actions/index.ts new file mode 100644 index 0000000000..43485c7b7e --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/index.ts @@ -0,0 +1,19 @@ +/** + * + * ACTIONS + * + * Small interactive components for user actions. + * + */ + +/** Styled icon button for action triggers with tooltip. */ +export { default as ActionIcon } from './ActionIcon.svelte'; + +/** Code block actions component (copy, preview). */ +export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte'; + +/** Copy-to-clipboard icon button with click handler. */ +export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte'; + +/** Remove/delete icon button with X icon. */ +export { default as ActionIconRemove } from './ActionIconRemove.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeInfo.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeInfo.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte diff --git a/tools/server/webui/src/lib/components/app/badges/index.ts b/tools/server/webui/src/lib/components/app/badges/index.ts new file mode 100644 index 0000000000..860afe3084 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/index.ts @@ -0,0 +1,16 @@ +/** + * + * BADGES & INDICATORS + * + * Small visual indicators for status and metadata. + * + */ + +/** Badge displaying chat statistics (tokens, timing). */ +export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte'; + +/** Generic info badge with optional tooltip and click handler. */ +export { default as BadgeInfo } from './BadgeInfo.svelte'; + +/** Badge indicating model modality (vision, audio, tools). */ +export { default as BadgeModality } from './BadgeModality.svelte'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte index 0b0bf52ad9..f05bdd8a03 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentPreview.svelte @@ -8,7 +8,8 @@ isImageFile, isPdfFile, isAudioFile, - getLanguageFromFilename + getLanguageFromFilename, + createBase64DataUrl } from '$lib/utils'; import { convertPDFToImage } from '$lib/utils/browser-only'; import { modelsStore } from '$lib/stores/models.svelte'; @@ -255,7 +256,7 @@ diff --git a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte index 908db5894b..9d32ea0721 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatAttachments/ChatAttachmentThumbnailFile.svelte @@ -1,5 +1,5 @@ @@ -107,67 +77,40 @@ {#if displayItems.length > 0}
{#if limitToSingleRow} -
- - -
- {#each displayItems as item (item.id)} - {#if item.isImage && item.preview} - openPreview(item, event)} - /> - {:else} - openPreview(item, event)} - /> - {/if} - {/each} -
- - -
+ (isScrollable = scrollable)} + > + {#each displayItems as item (item.id)} + {#if item.isImage && item.preview} + openPreview(item, event)} + /> + {:else} + openPreview(item, event)} + /> + {/if} + {/each} + {#if showViewAll}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 27ab975cbd..3551b0b3d6 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -1,20 +1,19 @@
{ + e.preventDefault(); + if (!canSubmit || disabled || isLoading || hasLoadingAttachments) return; + onSubmit?.(); + }} > - -
- - 0 || uploadedFiles.length > 0} - hasText={message.trim().length > 0} - {disabled} - {isLoading} - {isRecording} - {uploadedFiles} - onFileUpload={handleFileUpload} - onMicClick={handleMicClick} - onStop={handleStop} - /> +
+ { + onValueChange?.(value); + }} + {disabled} + {placeholder} + /> + + 0} + {disabled} + {isLoading} + {isRecording} + {uploadedFiles} + onFileUpload={handleFileUpload} + onMicClick={handleMicClick} + {onStop} + onSystemPromptClick={() => onSystemPromptClick?.({ message: value, files: uploadedFiles })} + /> +
- - diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte new file mode 100644 index 0000000000..b1cff67dcb --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte @@ -0,0 +1,173 @@ + + +
+ + + + + + + + +

{fileUploadTooltipText}

+
+
+
+ + + {#if hasVisionModality} + onFileUpload?.()} + > + + + Images + + {:else} + + + + + + Images + + + + +

Images require vision models to be processed

+
+
+ {/if} + + {#if hasAudioModality} + onFileUpload?.()} + > + + + Audio Files + + {:else} + + + + + + Audio Files + + + + +

Audio files require audio models to be processed

+
+
+ {/if} + + onFileUpload?.()} + > + + + Text Files + + + {#if hasVisionModality} + onFileUpload?.()} + > + + + PDF Files + + {:else} + + + onFileUpload?.()} + > + + + PDF Files + + + + +

PDFs will be converted to text. Image-based PDFs may not work properly.

+
+
+ {/if} + + + + onSystemPromptClick?.()} + > + + + System Message + + + + +

{systemMessageTooltip}

+
+
+
+
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte deleted file mode 100644 index dd37268096..0000000000 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ /dev/null @@ -1,123 +0,0 @@ - - -
- - - - - - - - -

{fileUploadTooltipText}

-
-
-
- - - - - onFileUpload?.()} - > - - - Images - - - - {#if !hasVisionModality} - -

Images require vision models to be processed

-
- {/if} -
- - - - onFileUpload?.()} - > - - - Audio Files - - - - {#if !hasAudioModality} - -

Audio files require audio models to be processed

-
- {/if} -
- - onFileUpload?.()} - > - - - Text Files - - - - - onFileUpload?.()} - > - - - PDF Files - - - - {#if !hasVisionModality} - -

PDFs will be converted to text. Image-based PDFs may not work properly.

-
- {/if} -
-
-
-
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte index dde9bda2d8..54b11c8624 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActions.svelte @@ -2,7 +2,7 @@ import { Square } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; import { - ChatFormActionFileAttachments, + ChatFormActionAttachmentsDropdown, ChatFormActionRecord, ChatFormActionSubmit, ModelsSelector @@ -13,8 +13,7 @@ import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; import { isRouterMode } from '$lib/stores/server.svelte'; import { chatStore } from '$lib/stores/chat.svelte'; - import { activeMessages, usedModalities } from '$lib/stores/conversations.svelte'; - import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte'; + import { activeMessages } from '$lib/stores/conversations.svelte'; interface Props { canSend?: boolean; @@ -27,6 +26,7 @@ onFileUpload?: () => void; onMicClick?: () => void; onStop?: () => void; + onSystemPromptClick?: () => void; } let { @@ -39,7 +39,8 @@ uploadedFiles = [], onFileUpload, onMicClick, - onStop + onStop, + onSystemPromptClick }: Props = $props(); let currentConfig = $derived(config()); @@ -152,43 +153,41 @@ export function openModelSelector() { selectorModelRef?.open(); } - - const { handleModelChange } = useModelChangeValidation({ - getRequiredModalities: () => usedModalities(), - onValidationFailure: async (previousModelId) => { - if (previousModelId) { - await modelsStore.selectModelById(previousModelId); - } - } - });
- +
+ +
- +
+ +
{#if isLoading} {:else if shouldShowRecordButton} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte index 19b763f55e..f0855b9dbe 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormTextarea.svelte @@ -5,6 +5,7 @@ interface Props { class?: string; disabled?: boolean; + onInput?: () => void; onKeydown?: (event: KeyboardEvent) => void; onPaste?: (event: ClipboardEvent) => void; placeholder?: string; @@ -14,6 +15,7 @@ let { class: className = '', disabled = false, + onInput, onKeydown, onPaste, placeholder = 'Ask anything...', @@ -52,7 +54,10 @@ class:cursor-not-allowed={disabled} {disabled} onkeydown={onKeydown} - oninput={(event) => autoResizeTextarea(event.currentTarget)} + oninput={(event) => { + autoResizeTextarea(event.currentTarget); + onInput?.(); + }} onpaste={onPaste} {placeholder} > diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index 220276fc9e..ebf7f433d1 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -1,121 +1,131 @@ -{#if message.role === 'system'} +{#if message.role === MessageRole.SYSTEM} -{:else if message.role === 'user'} +{:else if message.role === MessageRole.USER} (shouldBranchAfterEdit = value)} {showDeleteDialog} {siblingInfo} - {thinkingContent} - {toolCallContent} /> {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte index 3cb48157d8..97b34e92cc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageActions.svelte @@ -1,13 +1,15 @@ -
+
@@ -64,23 +72,33 @@
- + {#if onEdit} - + {/if} - {#if role === 'assistant' && onRegenerate} - onRegenerate()} /> + {#if role === MessageRole.ASSISTANT && onRegenerate} + onRegenerate()} /> {/if} - {#if role === 'assistant' && onContinue} - + {#if role === MessageRole.ASSISTANT && onContinue} + {/if} - +
+ + {#if showRawOutputSwitch} +
+ Show raw output + onRawOutputToggle?.(checked)} + /> +
+ {/if}
import { - ModelBadge, ChatMessageActions, ChatMessageStatistics, - ChatMessageThinkingBlock, - CopyToClipboardIcon, MarkdownContent, + ModelBadge, ModelsSelector } from '$lib/components/app'; + import ChatMessageThinkingBlock from './ChatMessageThinkingBlock.svelte'; + import { getMessageEditContext } from '$lib/contexts'; import { useProcessingState } from '$lib/hooks/use-processing-state.svelte'; - import { useModelChangeValidation } from '$lib/hooks/use-model-change-validation.svelte'; - import { isLoading } from '$lib/stores/chat.svelte'; - import { autoResizeTextarea, copyToClipboard } from '$lib/utils'; + import { isLoading, isChatStreaming } from '$lib/stores/chat.svelte'; + import { autoResizeTextarea, copyToClipboard, isIMEComposing } from '$lib/utils'; + import { tick } from 'svelte'; import { fade } from 'svelte/transition'; - import { Check, X, Wrench } from '@lucide/svelte'; + import { Check, X } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; import { Checkbox } from '$lib/components/ui/checkbox'; - import { INPUT_CLASSES } from '$lib/constants/input-classes'; + import { INPUT_CLASSES } from '$lib/constants/css-classes'; + import { MessageRole, KeyboardKey } from '$lib/enums'; import Label from '$lib/components/ui/label/label.svelte'; import { config } from '$lib/stores/settings.svelte'; - import { conversationsStore } from '$lib/stores/conversations.svelte'; import { isRouterMode } from '$lib/stores/server.svelte'; + import { modelsStore } from '$lib/stores/models.svelte'; + import { ServerModelStatus } from '$lib/enums'; + import { REASONING_TAGS } from '$lib/constants/agentic'; interface Props { class?: string; @@ -30,150 +33,198 @@ assistantMessages: number; messageTypes: string[]; } | null; - editedContent?: string; - isEditing?: boolean; + isLastAssistantMessage?: boolean; message: DatabaseMessage; messageContent: string | undefined; - onCancelEdit?: () => void; onCopy: () => void; onConfirmDelete: () => void; onContinue?: () => void; onDelete: () => void; onEdit?: () => void; - onEditKeydown?: (event: KeyboardEvent) => void; - onEditedContentChange?: (content: string) => void; onNavigateToSibling?: (siblingId: string) => void; onRegenerate: (modelOverride?: string) => void; - onSaveEdit?: () => void; onShowDeleteDialogChange: (show: boolean) => void; - onShouldBranchAfterEditChange?: (value: boolean) => void; showDeleteDialog: boolean; - shouldBranchAfterEdit?: boolean; siblingInfo?: ChatMessageSiblingInfo | null; textareaElement?: HTMLTextAreaElement; - thinkingContent: string | null; - toolCallContent: ApiChatCompletionToolCall[] | string | null; + } + + interface ParsedReasoningContent { + content: string; + reasoningContent: string | null; + hasReasoningMarkers: boolean; + } + + function parseReasoningContent(content: string | undefined): ParsedReasoningContent { + if (!content) { + return { + content: '', + reasoningContent: null, + hasReasoningMarkers: false + }; + } + + const plainParts: string[] = []; + const reasoningParts: string[] = []; + const { START, END } = REASONING_TAGS; + let cursor = 0; + let hasReasoningMarkers = false; + + while (cursor < content.length) { + const startIndex = content.indexOf(START, cursor); + + if (startIndex === -1) { + plainParts.push(content.slice(cursor)); + break; + } + + hasReasoningMarkers = true; + plainParts.push(content.slice(cursor, startIndex)); + + const reasoningStart = startIndex + START.length; + const endIndex = content.indexOf(END, reasoningStart); + + if (endIndex === -1) { + reasoningParts.push(content.slice(reasoningStart)); + cursor = content.length; + break; + } + + reasoningParts.push(content.slice(reasoningStart, endIndex)); + cursor = endIndex + END.length; + } + + return { + content: plainParts.join(''), + reasoningContent: reasoningParts.length > 0 ? reasoningParts.join('\n\n') : null, + hasReasoningMarkers + }; } let { class: className = '', deletionInfo, - editedContent = '', - isEditing = false, + isLastAssistantMessage = false, message, messageContent, - onCancelEdit, onConfirmDelete, onContinue, onCopy, onDelete, onEdit, - onEditKeydown, - onEditedContentChange, onNavigateToSibling, onRegenerate, - onSaveEdit, onShowDeleteDialogChange, - onShouldBranchAfterEditChange, showDeleteDialog, - shouldBranchAfterEdit = false, siblingInfo = null, - textareaElement = $bindable(), - thinkingContent, - toolCallContent = null + textareaElement = $bindable() }: Props = $props(); - const toolCalls = $derived( - Array.isArray(toolCallContent) ? (toolCallContent as ApiChatCompletionToolCall[]) : null - ); - const fallbackToolCalls = $derived(typeof toolCallContent === 'string' ? toolCallContent : null); + // Get edit context + const editCtx = getMessageEditContext(); + // Local state for assistant-specific editing + let shouldBranchAfterEdit = $state(false); + + function handleEditKeydown(event: KeyboardEvent) { + if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) { + event.preventDefault(); + editCtx.save(); + } else if (event.key === KeyboardKey.ESCAPE) { + event.preventDefault(); + editCtx.cancel(); + } + } + + const parsedMessageContent = $derived.by(() => parseReasoningContent(messageContent)); + const visibleMessageContent = $derived(parsedMessageContent.content); + const thinkingContent = $derived(parsedMessageContent.reasoningContent); + const hasReasoningMarkers = $derived(parsedMessageContent.hasReasoningMarkers); const processingState = useProcessingState(); let currentConfig = $derived(config()); let isRouter = $derived(isRouterMode()); - let displayedModel = $derived((): string | null => { - if (message.model) { - return message.model; + let showRawOutput = $state(false); + let statsContainerEl: HTMLDivElement | undefined = $state(); + + function getScrollParent(el: HTMLElement): HTMLElement | null { + let parent = el.parentElement; + while (parent) { + const style = getComputedStyle(parent); + if (/(auto|scroll)/.test(style.overflowY)) { + return parent; + } + parent = parent.parentElement; + } + return null; + } + + async function handleStatsViewChange() { + const el = statsContainerEl; + if (!el) { + return; } - return null; - }); + const scrollParent = getScrollParent(el); + if (!scrollParent) { + return; + } - const { handleModelChange } = useModelChangeValidation({ - getRequiredModalities: () => conversationsStore.getModalitiesUpToMessage(message.id), - onSuccess: (modelName) => onRegenerate(modelName) - }); + const yBefore = el.getBoundingClientRect().top; + + await tick(); + + const delta = el.getBoundingClientRect().top - yBefore; + if (delta !== 0) { + scrollParent.scrollTop += delta; + } + + // Correct any drift after browser paint + requestAnimationFrame(() => { + const drift = el.getBoundingClientRect().top - yBefore; + + if (Math.abs(drift) > 1) { + scrollParent.scrollTop += drift; + } + }); + } + + let displayedModel = $derived(message.model ?? null); + + let isCurrentlyLoading = $derived(isLoading()); + let isStreaming = $derived(isChatStreaming()); + let hasNoContent = $derived(!visibleMessageContent?.trim()); + let isActivelyProcessing = $derived(isCurrentlyLoading || isStreaming); + + let showProcessingInfoTop = $derived( + message?.role === MessageRole.ASSISTANT && + isActivelyProcessing && + hasNoContent && + isLastAssistantMessage + ); + + let showProcessingInfoBottom = $derived( + message?.role === MessageRole.ASSISTANT && + isActivelyProcessing && + !hasNoContent && + isLastAssistantMessage + ); function handleCopyModel() { - const model = displayedModel(); - - void copyToClipboard(model ?? ''); + void copyToClipboard(displayedModel ?? ''); } $effect(() => { - if (isEditing && textareaElement) { + if (editCtx.isEditing && textareaElement) { autoResizeTextarea(textareaElement); } }); $effect(() => { - if (isLoading() && !message?.content?.trim()) { + if (showProcessingInfoTop || showProcessingInfoBottom) { processingState.startMonitoring(); } }); - - function formatToolCallBadge(toolCall: ApiChatCompletionToolCall, index: number) { - const callNumber = index + 1; - const functionName = toolCall.function?.name?.trim(); - const label = functionName || `Call #${callNumber}`; - - const payload: Record = {}; - - const id = toolCall.id?.trim(); - if (id) { - payload.id = id; - } - - const type = toolCall.type?.trim(); - if (type) { - payload.type = type; - } - - if (toolCall.function) { - const fnPayload: Record = {}; - - const name = toolCall.function.name?.trim(); - if (name) { - fnPayload.name = name; - } - - const rawArguments = toolCall.function.arguments?.trim(); - if (rawArguments) { - try { - fnPayload.arguments = JSON.parse(rawArguments); - } catch { - fnPayload.arguments = rawArguments; - } - } - - if (Object.keys(fnPayload).length > 0) { - payload.function = fnPayload; - } - } - - const formattedPayload = JSON.stringify(payload, null, 2); - - return { - label, - tooltip: formattedPayload, - copyValue: formattedPayload - }; - } - - function handleCopyToolCall(payload: string) { - void copyToClipboard(payload, 'Tool call copied to clipboard'); - }
- {#if thinkingContent} + {#if !editCtx.isEditing && thinkingContent} {/if} - {#if message?.role === 'assistant' && isLoading() && !message?.content?.trim()} + {#if showProcessingInfoTop}
- {processingState.getPromptProgressText() ?? processingState.getProcessingMessage()} + {processingState.getPromptProgressText() ?? + processingState.getProcessingMessage() ?? + 'Processing...'}
{/if} - {#if isEditing} + {#if editCtx.isEditing}
@@ -218,30 +271,35 @@ onShouldBranchAfterEditChange?.(checked === true)} + onCheckedChange={(checked) => (shouldBranchAfterEdit = checked === true)} />
- -
- {:else if message.role === 'assistant'} - {#if config().disableReasoningFormat} + {:else if message.role === MessageRole.ASSISTANT} + {#if showRawOutput}
{messageContent || ''}
{:else} - + {/if} {:else}
@@ -249,18 +307,41 @@
{/if} + {#if showProcessingInfoBottom} +
+
+ + {processingState.getPromptProgressText() ?? + processingState.getProcessingMessage() ?? + 'Processing...'} + +
+
+ {/if} +
- {#if displayedModel()} -
+ {#if displayedModel} +
{#if isRouter} { + const status = modelsStore.getModelStatus(modelId); + + if (status !== ServerModelStatus.LOADED) { + await modelsStore.loadModel(modelId); + } + + onRegenerate(modelName); + return true; + }} /> {:else} - + {/if} {#if currentConfig.showMessageStats && message.timings && message.timings.predicted_n && message.timings.predicted_ms} @@ -269,6 +350,7 @@ promptMs={message.timings.prompt_ms} predictedTokens={message.timings.predicted_n} predictedMs={message.timings.predicted_ms} + onActiveViewChange={handleStatsViewChange} /> {:else if isLoading() && currentConfig.showMessageStats} {@const liveStats = processingState.getLiveProcessingStats()} @@ -290,53 +372,11 @@ {/if}
{/if} - - {#if config().showToolCalls} - {#if (toolCalls && toolCalls.length > 0) || fallbackToolCalls} - - - - - Tool calls: - - - {#if toolCalls && toolCalls.length > 0} - {#each toolCalls as toolCall, index (toolCall.id ?? `${index}`)} - {@const badge = formatToolCallBadge(toolCall, index)} - - {/each} - {:else if fallbackToolCalls} - - {/if} - - {/if} - {/if}
- {#if message.timestamp && !isEditing} + {#if message.timestamp && !editCtx.isEditing} (showRawOutput = enabled)} /> {/if}
@@ -402,17 +445,4 @@ white-space: pre-wrap; word-break: break-word; } - - .tool-call-badge { - max-width: 12rem; - white-space: nowrap; - overflow: hidden; - text-overflow: ellipsis; - } - - .tool-call-badge--fallback { - max-width: 20rem; - white-space: normal; - word-break: break-word; - } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte index f812ea2fd9..299bdc78fc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte @@ -1,79 +1,26 @@ - - -
- { - if (fileId.startsWith('attachment-')) { - const index = parseInt(fileId.replace('attachment-', ''), 10); - if (!isNaN(index) && index >= 0 && index < editedExtras.length) { - handleRemoveExistingAttachment(index); - } - } else { - handleRemoveUploadedFile(fileId); - } - }} - limitToSingleRow - class="py-5" - style="scroll-padding: 1rem;" +
+ - -
- - -
- - -
- - {#if isRouter} - - {/if} - - -
-
- {#if showSaveOnlyOption && onSaveEditOnly} + {#if editCtx.showSaveOnlyOption}
@@ -386,6 +131,6 @@ cancelText="Keep editing" variant="destructive" icon={AlertTriangle} - onConfirm={onCancelEdit} + onConfirm={editCtx.cancel} onCancel={() => (showDiscardDialog = false)} /> diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte index 24fe5926ba..77951e9d2a 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte @@ -3,18 +3,18 @@ import { BadgeChatStatistic } from '$lib/components/app'; import * as Tooltip from '$lib/components/ui/tooltip'; import { ChatMessageStatsView } from '$lib/enums'; + import { formatPerformanceTime } from '$lib/utils'; + import { MS_PER_SECOND, DEFAULT_PERFORMANCE_TIME } from '$lib/constants/formatters'; interface Props { predictedTokens?: number; predictedMs?: number; promptTokens?: number; promptMs?: number; - // Live mode: when true, shows stats during streaming isLive?: boolean; - // Whether prompt processing is still in progress isProcessingPrompt?: boolean; - // Initial view to show (defaults to READING in live mode) initialView?: ChatMessageStatsView; + onActiveViewChange?: (view: ChatMessageStatsView) => void; } let { @@ -24,12 +24,17 @@ promptMs, isLive = false, isProcessingPrompt = false, - initialView = ChatMessageStatsView.GENERATION + initialView = ChatMessageStatsView.GENERATION, + onActiveViewChange }: Props = $props(); - let activeView: ChatMessageStatsView = $state(initialView); + let activeView: ChatMessageStatsView = $derived(initialView); let hasAutoSwitchedToGeneration = $state(false); + $effect(() => { + onActiveViewChange?.(activeView); + }); + // In live mode: auto-switch to GENERATION tab when prompt processing completes $effect(() => { if (isLive) { @@ -56,26 +61,28 @@ predictedMs > 0 ); - let tokensPerSecond = $derived(hasGenerationStats ? (predictedTokens! / predictedMs!) * 1000 : 0); - let timeInSeconds = $derived( - predictedMs !== undefined ? (predictedMs / 1000).toFixed(2) : '0.00' + let tokensPerSecond = $derived( + hasGenerationStats ? (predictedTokens! / predictedMs!) * MS_PER_SECOND : 0 + ); + let formattedTime = $derived( + predictedMs !== undefined ? formatPerformanceTime(predictedMs) : DEFAULT_PERFORMANCE_TIME ); let promptTokensPerSecond = $derived( promptTokens !== undefined && promptMs !== undefined && promptMs > 0 - ? (promptTokens / promptMs) * 1000 + ? (promptTokens / promptMs) * MS_PER_SECOND : undefined ); - let promptTimeInSeconds = $derived( - promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined + let formattedPromptTime = $derived( + promptMs !== undefined ? formatPerformanceTime(promptMs) : undefined ); let hasPromptStats = $derived( promptTokens !== undefined && promptMs !== undefined && promptTokensPerSecond !== undefined && - promptTimeInSeconds !== undefined + formattedPromptTime !== undefined ); // In live mode, generation tab is disabled until we have generation stats @@ -96,9 +103,11 @@ onclick={() => (activeView = ChatMessageStatsView.READING)} > + Reading +

Reading (prompt processing)

@@ -118,9 +127,11 @@ disabled={isGenerationDisabled} > + Generation +

{isGenerationDisabled @@ -139,16 +150,18 @@ value="{predictedTokens?.toLocaleString()} tokens" tooltipLabel="Generated tokens" /> + + {:else if hasPromptStats} @@ -158,12 +171,14 @@ value="{promptTokens} tokens" tooltipLabel="Prompt tokens" /> + + void; - onSaveEdit: () => void; - onEditKeydown: (event: KeyboardEvent) => void; - onEditedContentChange: (content: string) => void; onCopy: () => void; onEdit: () => void; onDelete: () => void; @@ -36,15 +33,9 @@ let { class: className = '', message, - isEditing, - editedContent, siblingInfo = null, showDeleteDialog, deletionInfo, - onCancelEdit, - onSaveEdit, - onEditKeydown, - onEditedContentChange, onCopy, onEdit, onDelete, @@ -54,10 +45,25 @@ textareaElement = $bindable() }: Props = $props(); + const editCtx = getMessageEditContext(); + + function handleEditKeydown(event: KeyboardEvent) { + if (event.key === KeyboardKey.ENTER && !event.shiftKey && !isIMEComposing(event)) { + event.preventDefault(); + + editCtx.save(); + } else if (event.key === KeyboardKey.ESCAPE) { + event.preventDefault(); + + editCtx.cancel(); + } + } + let isMultiline = $state(false); let messageElement: HTMLElement | undefined = $state(); let isExpanded = $state(false); let contentHeight = $state(0); + const MAX_HEIGHT = 200; // pixels const currentConfig = config(); @@ -97,26 +103,33 @@ class="group flex flex-col items-end gap-3 md:gap-2 {className}" role="group" > - {#if isEditing} + {#if editCtx.isEditing}

- -
@@ -131,12 +144,12 @@ type="button" >
{#if currentConfig.renderUserContentAsMarkdown}
- +
{:else}
+
@@ -208,7 +225,7 @@ {onShowDeleteDialogChange} {siblingInfo} {showDeleteDialog} - role="user" + role={MessageRole.USER} />
{/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte index 041c6bd251..05a02e2728 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageUser.svelte @@ -1,67 +1,48 @@ -
- {#each displayMessages as { message, siblingInfo } (message.id)} +
+ {#each displayMessages as { message, isLastAssistantMessage, siblingInfo } (message.id)} {/each}
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 27439551a1..ceecf03e54 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -1,7 +1,7 @@ @@ -388,11 +363,8 @@ class="mb-16 md:mb-24" messages={activeMessages()} onUserAction={() => { - if (!disableAutoScroll) { - userScrolledUp = false; - autoScrollEnabled = true; - scrollChatToBottom(); - } + autoScroll.enable(); + autoScroll.scrollToBottom(); }} /> @@ -426,13 +398,15 @@ {/if}
- chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={false} bind:uploadedFiles /> @@ -454,7 +428,7 @@ >
-

llama.cpp

+

llama.cpp

{serverStore.props?.modalities?.audio @@ -484,13 +458,15 @@ {/if}

- chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={true} bind:uploadedFiles /> @@ -595,7 +571,7 @@ contextInfo={activeErrorDialog?.contextInfo} onOpenChange={handleErrorDialogOpenChange} open={Boolean(activeErrorDialog)} - type={activeErrorDialog?.type ?? 'server'} + type={activeErrorDialog?.type ?? ErrorDialogType.SERVER} /> diff --git a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte similarity index 90% rename from tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte rename to tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte index bc42f9dd1e..625fdc7b1b 100644 --- a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte +++ b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte @@ -71,13 +71,11 @@
-
{@html highlightedHtml}
+
{@html highlightedHtml}