diff --git a/.devops/rocm.Dockerfile b/.devops/rocm.Dockerfile index 14936f8e9c..830fe19e3e 100644 --- a/.devops/rocm.Dockerfile +++ b/.devops/rocm.Dockerfile @@ -1,8 +1,8 @@ ARG UBUNTU_VERSION=24.04 # This needs to generally match the container host's environment. -ARG ROCM_VERSION=7.0 -ARG AMDGPU_VERSION=7.0 +ARG ROCM_VERSION=7.2 +ARG AMDGPU_VERSION=7.2 # Target the ROCm build image ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-complete @@ -11,13 +11,12 @@ ARG BASE_ROCM_DEV_CONTAINER=rocm/dev-ubuntu-${UBUNTU_VERSION}:${ROCM_VERSION}-co FROM ${BASE_ROCM_DEV_CONTAINER} AS build # Unless otherwise specified, we make a fat build. -# List from https://github.com/ggml-org/llama.cpp/pull/1087#issuecomment-1682807878 # This is mostly tied to rocBLAS supported archs. -# gfx803, gfx900, gfx906, gfx1032, gfx1101, gfx1102,not officialy supported -# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-6.4.1/reference/system-requirements.html +# check https://rocm.docs.amd.com/projects/install-on-linux/en/docs-7.2.0/reference/system-requirements.html +# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityrad/native_linux/native_linux_compatibility.html +# check https://rocm.docs.amd.com/projects/radeon-ryzen/en/latest/docs/compatibility/compatibilityryz/native_linux/native_linux_compatibility.html -ARG ROCM_DOCKER_ARCH='gfx803;gfx900;gfx906;gfx908;gfx90a;gfx942;gfx1010;gfx1030;gfx1032;gfx1100;gfx1101;gfx1102;gfx1200;gfx1201;gfx1151' -#ARG ROCM_DOCKER_ARCH='gfx1151' +ARG ROCM_DOCKER_ARCH='gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201' # Set ROCm architectures ENV AMDGPU_TARGETS=${ROCM_DOCKER_ARCH} diff --git a/.github/actions/windows-setup-rocm/action.yml b/.github/actions/windows-setup-rocm/action.yml index b83e6e295b..fd9f8e5a41 100644 --- a/.github/actions/windows-setup-rocm/action.yml +++ b/.github/actions/windows-setup-rocm/action.yml @@ -11,5 +11,5 @@ runs: - name: Setup ROCm uses: ./.github/actions/install-exe with: - url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-WinSvr2022-For-HIP.exe + url: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ inputs.version }}-Win11-For-HIP.exe args: -install diff --git a/.github/workflows/build-cache.yml b/.github/workflows/build-cache.yml index 3de0be9fad..18a6515117 100644 --- a/.github/workflows/build-cache.yml +++ b/.github/workflows/build-cache.yml @@ -68,7 +68,7 @@ jobs: env: # Make sure this is in sync with build.yml - HIPSDK_INSTALLER_VERSION: "25.Q3" + HIPSDK_INSTALLER_VERSION: "26.Q1" steps: - name: Clone diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 6c7ab71143..30365a3613 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1175,10 +1175,8 @@ jobs: runs-on: windows-2022 env: - # The ROCm version must correspond to the version used in the HIP SDK. - ROCM_VERSION: "6.4.2" # Make sure this is in sync with build-cache.yml - HIPSDK_INSTALLER_VERSION: "25.Q3" + HIPSDK_INSTALLER_VERSION: "26.Q1" steps: - name: Clone @@ -1188,7 +1186,7 @@ jobs: - name: Grab rocWMMA package id: grab_rocwmma run: | - curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/${{ env.ROCM_VERSION }}/pool/main/r/rocwmma-dev/rocwmma-dev_1.7.0.60402-120~24.04_amd64.deb" + curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb" 7z x rocwmma.deb 7z x data.tar @@ -1231,7 +1229,7 @@ jobs: cmake -G "Unix Makefiles" -B build -S . ` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-${{ env.ROCM_VERSION }}/include/" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/" ` -DCMAKE_BUILD_TYPE=Release ` -DLLAMA_BUILD_BORINGSSL=ON ` -DROCM_DIR="${env:HIP_PATH}" ` diff --git a/.github/workflows/gguf-publish.yml b/.github/workflows/gguf-publish.yml index 0e95766459..5bdab0f157 100644 --- a/.github/workflows/gguf-publish.yml +++ b/.github/workflows/gguf-publish.yml @@ -21,7 +21,7 @@ on: jobs: deploy: - runs-on: ubuntu-slim + runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 1914c08489..1f79a83815 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -516,17 +516,113 @@ jobs: path: llama-bin-win-sycl-x64.zip name: llama-bin-win-sycl-x64.zip + ubuntu-22-rocm: + runs-on: ubuntu-22.04 + + strategy: + matrix: + include: + - ROCM_VERSION: "7.2" + gpu_targets: "gfx908;gfx90a;gfx942;gfx1030;gfx1100;gfx1101;gfx1151;gfx1150;gfx1200;gfx1201" + build: 'x64' + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v6 + with: + fetch-depth: 0 + + - name: ccache + uses: ggml-org/ccache-action@v1.2.16 + with: + key: ubuntu-rocm-cmake-${{ matrix.ROCM_VERSION }}-${{ matrix.build }} + evict-old-files: 1d + + - name: Dependencies + id: depends + run: | + sudo apt install -y build-essential git cmake wget + + - name: Setup Legacy ROCm + if: matrix.ROCM_VERSION == '7.2' + id: legacy_env + run: | + sudo mkdir --parents --mode=0755 /etc/apt/keyrings + wget https://repo.radeon.com/rocm/rocm.gpg.key -O - | \ + gpg --dearmor | sudo tee /etc/apt/keyrings/rocm.gpg > /dev/null + + sudo tee /etc/apt/sources.list.d/rocm.list << EOF + deb [arch=amd64 signed-by=/etc/apt/keyrings/rocm.gpg] https://repo.radeon.com/rocm/apt/${{ matrix.ROCM_VERSION }} jammy main + EOF + + sudo tee /etc/apt/preferences.d/rocm-pin-600 << EOF + Package: * + Pin: release o=repo.radeon.com + Pin-Priority: 600 + EOF + + sudo apt update + sudo apt-get install -y libssl-dev rocm-hip-sdk + + - name: Setup TheRock + if: matrix.ROCM_VERSION != '7.2' + id: therock_env + run: | + wget https://repo.amd.com/rocm/tarball/therock-dist-linux-gfx1151-${{ matrix.ROCM_VERSION }}.tar.gz + mkdir install + tar -xf *.tar.gz -C install + export ROCM_PATH=$(pwd)/install + echo ROCM_PATH=$ROCM_PATH >> $GITHUB_ENV + echo PATH=$PATH:$ROCM_PATH/bin >> $GITHUB_ENV + echo LD_LIBRARY_PATH=$ROCM_PATH/lib:$ROCM_PATH/llvm/lib:$ROCM_PATH/lib/rocprofiler-systems >> $GITHUB_ENV + + - name: Build with native CMake HIP support + id: cmake_build + run: | + cmake -B build -S . \ + -DCMAKE_HIP_COMPILER="$(hipconfig -l)/clang" \ + -DCMAKE_HIP_FLAGS="-mllvm --amdgpu-unroll-threshold-local=600" \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_BACKEND_DL=ON \ + -DGGML_NATIVE=OFF \ + -DCMAKE_INSTALL_RPATH='$ORIGIN' \ + -DCMAKE_BUILD_WITH_INSTALL_RPATH=ON \ + -DGGML_CPU_ALL_VARIANTS=ON \ + -DGPU_TARGETS="${{ matrix.gpu_targets }}" \ + -DGGML_HIP=ON \ + -DHIP_PLATFORM=amd \ + -DGGML_HIP_ROCWMMA_FATTN=ON \ + ${{ env.CMAKE_ARGS }} + cmake --build build --config Release -j $(nproc) + + - name: Determine tag name + id: tag + uses: ./.github/actions/get-tag-name + + - name: Pack artifacts + id: pack_artifacts + run: | + cp LICENSE ./build/bin/ + tar -czvf llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz --transform "s,./,llama-${{ steps.tag.outputs.name }}/," -C ./build/bin . + + - name: Upload artifacts + uses: actions/upload-artifact@v6 + with: + path: llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz + name: llama-bin-ubuntu-rocm-${{ matrix.ROCM_VERSION }}-${{ matrix.build }}.tar.gz + windows-hip: runs-on: windows-2022 env: - HIPSDK_INSTALLER_VERSION: "25.Q3" + HIPSDK_INSTALLER_VERSION: "26.Q1" strategy: matrix: include: - name: "radeon" - gpu_targets: "gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032" + gpu_targets: "gfx1150;gfx1151;gfx1200;gfx1201;gfx1100;gfx1101;gfx1102;gfx1030;gfx1031;gfx1032" steps: - name: Clone @@ -536,7 +632,7 @@ jobs: - name: Grab rocWMMA package id: grab_rocwmma run: | - curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.0.1/pool/main/r/rocwmma-dev/rocwmma-dev_2.0.0.70001-42~24.04_amd64.deb" + curl -o rocwmma.deb "https://repo.radeon.com/rocm/apt/7.2/pool/main/r/rocwmma-dev/rocwmma-dev_2.2.0.70200-43~24.04_amd64.deb" 7z x rocwmma.deb 7z x data.tar @@ -559,7 +655,7 @@ jobs: run: | $ErrorActionPreference = "Stop" write-host "Downloading AMD HIP SDK Installer" - Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" + Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-${{ env.HIPSDK_INSTALLER_VERSION }}-Win11-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" write-host "Installing AMD HIP SDK" $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru $completed = $proc.WaitForExit(600000) @@ -593,20 +689,20 @@ jobs: cmake -G "Unix Makefiles" -B build -S . ` -DCMAKE_C_COMPILER="${env:HIP_PATH}\bin\clang.exe" ` -DCMAKE_CXX_COMPILER="${env:HIP_PATH}\bin\clang++.exe" ` - -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.0.1/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` + -DCMAKE_CXX_FLAGS="-I$($PWD.Path.Replace('\', '/'))/opt/rocm-7.2.0/include/ -Wno-ignored-attributes -Wno-nested-anon-types" ` -DCMAKE_BUILD_TYPE=Release ` -DGGML_BACKEND_DL=ON ` -DGGML_NATIVE=OFF ` -DGGML_CPU=OFF ` - -DAMDGPU_TARGETS="${{ matrix.gpu_targets }}" ` + -DGPU_TARGETS="${{ matrix.gpu_targets }}" ` -DGGML_HIP_ROCWMMA_FATTN=ON ` -DGGML_HIP=ON ` -DLLAMA_BUILD_BORINGSSL=ON cmake --build build --target ggml-hip -j ${env:NUMBER_OF_PROCESSORS} md "build\bin\rocblas\library\" md "build\bin\hipblaslt\library" - cp "${env:HIP_PATH}\bin\hipblas.dll" "build\bin\" - cp "${env:HIP_PATH}\bin\hipblaslt.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\libhipblas.dll" "build\bin\" + cp "${env:HIP_PATH}\bin\libhipblaslt.dll" "build\bin\" cp "${env:HIP_PATH}\bin\rocblas.dll" "build\bin\" cp "${env:HIP_PATH}\bin\rocblas\library\*" "build\bin\rocblas\library\" cp "${env:HIP_PATH}\bin\hipblaslt\library\*" "build\bin\hipblaslt\library\" @@ -784,6 +880,7 @@ jobs: - windows-cuda - windows-sycl - windows-hip + - ubuntu-22-rocm - ubuntu-22-cpu - ubuntu-22-vulkan - macOS-arm64 @@ -868,6 +965,7 @@ jobs: **Linux:** - [Ubuntu x64 (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-x64.tar.gz) - [Ubuntu x64 (Vulkan)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-vulkan-x64.tar.gz) + - [Ubuntu x64 (ROCm 7.2)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-rocm-7.2-x64.tar.gz) - [Ubuntu s390x (CPU)](https://github.com/ggml-org/llama.cpp/releases/download/${{ steps.tag.outputs.name }}/llama-${{ steps.tag.outputs.name }}-bin-ubuntu-s390x.tar.gz) **Windows:** diff --git a/CMakeLists.txt b/CMakeLists.txt index 32542ecd27..69da97dc1e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. +cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories. project("llama.cpp" C CXX) include(CheckIncludeFileCXX) diff --git a/common/arg.cpp b/common/arg.cpp index 18f953a38e..05f4a5244e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1578,7 +1578,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_sparam()); add_opt(common_arg( - {"--temp"}, "N", + {"--temp", "--temperature"}, "N", string_format("temperature (default: %.2f)", (double)params.sampling.temp), [](common_params & params, const std::string & value) { params.sampling.temp = std::stof(value); @@ -1611,7 +1611,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_sparam()); add_opt(common_arg( - {"--top-nsigma"}, "N", + {"--top-nsigma", "--top-n-sigma"}, "N", string_format("top-n-sigma sampling (default: %.2f, -1.0 = disabled)", params.sampling.top_n_sigma), [](common_params & params, const std::string & value) { params.sampling.top_n_sigma = std::stof(value); @@ -1634,7 +1634,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } ).set_sparam()); add_opt(common_arg( - {"--typical"}, "N", + {"--typical", "--typical-p"}, "N", string_format("locally typical sampling, parameter p (default: %.2f, 1.0 = disabled)", (double)params.sampling.typ_p), [](common_params & params, const std::string & value) { params.sampling.typ_p = std::stof(value); @@ -2520,11 +2520,28 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"-a", "--alias"}, "STRING", - "set alias for model name (to be used by REST API)", + "set model name aliases, comma-separated (to be used by API)", [](common_params & params, const std::string & value) { - params.model_alias = value; + for (auto & alias : string_split(value, ',')) { + alias = string_strip(alias); + if (!alias.empty()) { + params.model_alias.insert(alias); + } + } } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ALIAS")); + add_opt(common_arg( + {"--tags"}, "STRING", + "set model tags, comma-separated (informational, not used for routing)", + [](common_params & params, const std::string & value) { + for (auto & tag : string_split(value, ',')) { + tag = string_strip(tag); + if (!tag.empty()) { + params.model_tags.insert(tag); + } + } + } + ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_TAGS")); add_opt(common_arg( {"-m", "--model"}, "FNAME", ex == LLAMA_EXAMPLE_EXPORT_LORA diff --git a/common/chat-parser-xml-toolcall.cpp b/common/chat-parser-xml-toolcall.cpp index a80900ff8d..ba359fdbf4 100644 --- a/common/chat-parser-xml-toolcall.cpp +++ b/common/chat-parser-xml-toolcall.cpp @@ -803,7 +803,7 @@ inline void parse_msg_with_xml_tool_calls(common_chat_msg_parser & builder, cons } // remove potential partial suffix - if (builder.pos() == builder.input().size()) { + if (builder.pos() == builder.input().size() && builder.is_partial()) { if (unclosed_reasoning_content.empty()) { rstrip(content); trim_potential_partial_word(content); diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index 29819e48d3..060578f0b7 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -893,23 +893,6 @@ static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) { builder.consume_reasoning_with_xml_tool_calls(form, "", ""); } -static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) { - static const xml_tool_call_format form = ([]() { - xml_tool_call_format form {}; - form.scope_start = ""; - form.tool_start = "") != std::string::npos); + // Handle thinking tags appropriately based on inputs.enable_thinking - if (string_ends_with(data.prompt, "\n")) { + if (supports_reasoning && string_ends_with(data.prompt, "\n")) { if (!inputs.enable_thinking) { data.prompt += ""; } else { @@ -1538,19 +1540,21 @@ static common_chat_params common_chat_params_init_nemotron_v3(const common_chat_ } data.preserved_tokens = { - "", - "", "", "", }; + if (supports_reasoning) { + data.preserved_tokens.insert(data.preserved_tokens.end(), {"", ""}); + } + auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); auto extract_reasoning = inputs.reasoning_format != COMMON_REASONING_FORMAT_NONE; auto include_grammar = true; auto parser = build_chat_peg_constructed_parser([&](auto & p) { auto reasoning = p.eps(); - if (inputs.enable_thinking && extract_reasoning) { + if (supports_reasoning && inputs.enable_thinking && extract_reasoning) { auto reasoning_content = p.reasoning(p.until("")) + ("" | p.end()); if (data.thinking_forced_open) { reasoning = reasoning_content; @@ -1888,38 +1892,6 @@ static common_chat_params common_chat_params_init_minimax_m2(const common_chat_t return data; } -static common_chat_params common_chat_params_init_qwen3_coder_xml(const common_chat_template & tmpl, const struct templates_params & params) { - common_chat_params data; - data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; - - data.prompt = apply(tmpl, params); - data.format = COMMON_CHAT_FORMAT_QWEN3_CODER_XML; - - data.preserved_tokens = { - "", - "", - "", - "", - }; - - // build grammar for tool call - static const xml_tool_call_format form { - /* form.scope_start = */ "\n", - /* form.tool_start = */ "\n", - /* form.key_start = */ "\n", - /* form.val_end = */ "\n\n", - /* form.tool_end = */ "\n", - /* form.scope_end = */ "", - }; - build_grammar_xml_tool_call(data, params.tools, form); - - return data; -} - static common_chat_params common_chat_params_init_kimi_k2(const common_chat_template & tmpl, const struct templates_params & params) { common_chat_params data; data.grammar_lazy = params.tools.is_array() && !params.tools.empty() && params.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED; @@ -3147,13 +3119,7 @@ static common_chat_params common_chat_templates_apply_jinja( src.find("") != std::string::npos) { - return common_chat_params_init_nemotron_v3(tmpl, params); - } - return common_chat_params_init_qwen3_coder_xml(tmpl, params); + return common_chat_params_init_qwen3_coder(tmpl, params); } // Xiaomi MiMo format detection (must come before Hermes 2 Pro) diff --git a/common/chat.h b/common/chat.h index 1bf43f7261..6f0b9409ec 100644 --- a/common/chat.h +++ b/common/chat.h @@ -128,7 +128,6 @@ enum common_chat_format { COMMON_CHAT_FORMAT_GLM_4_5, COMMON_CHAT_FORMAT_MINIMAX_M2, COMMON_CHAT_FORMAT_KIMI_K2, - COMMON_CHAT_FORMAT_QWEN3_CODER_XML, COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, COMMON_CHAT_FORMAT_SOLAR_OPEN, diff --git a/common/common.cpp b/common/common.cpp index 75116ed6f3..53bddc4ef2 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1760,3 +1760,65 @@ float lr_opt::get_lr(float epoch) const { LOG_INF("epoch %.2g lr=%.2g\n", epoch, r); return r; } + +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos) { + llama_batch batch = llama_batch_get_one(&last_token, 1); + batch.pos = &pos; + if (llama_decode(ctx, batch)) { + LOG_ERR("%s: failed to replay last token\n", __func__); + return false; + } + return true; +} + +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector & tokens, + int & n_past, + int n_batch, + std::string_view state_path, + bool save_state) { + const int n_eval = tokens.size(); + if (n_eval == 0) { + return true; + } + + if (save_state && n_eval > 1) { + const int n_tokens_before_last = n_eval - 1; + + GGML_ASSERT(n_eval <= n_batch); + + // Decode all but the last token so we can save the memory state before decoding the last token. + // This is done so we can restore the session state later and replay the last token. + // Memory implementations in recurrent/hybrid models don't support removing tokens from their + // memory, so we can't just remove the last token from the memory and replay the last token which + // is the reason for this logic. + if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_tokens_before_last))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_tokens_before_last; + + llama_state_save_file(ctx, state_path.data(), tokens.data(), n_tokens_before_last); + LOG_INF("saved session before last token to %s, n_tokens = %d\n", state_path.data(), n_tokens_before_last); + + llama_token last_token = tokens.back(); + llama_batch batch = llama_batch_get_one(&last_token, 1); + int32_t pos = n_past; + batch.pos = &pos; + + if (llama_decode(ctx, batch)) { + LOG_ERR("%s : failed to eval last token\n", __func__); + return false; + } + n_past++; + } else { + if (llama_decode(ctx, llama_batch_get_one(const_cast(tokens.data()), n_eval))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + n_past += n_eval; + } + + return true; +} diff --git a/common/common.h b/common/common.h index a4c431172d..c5a8037571 100644 --- a/common/common.h +++ b/common/common.h @@ -410,7 +410,8 @@ struct common_params { struct common_params_model model; - std::string model_alias = ""; // model alias // NOLINT + std::set model_alias; // model aliases // NOLINT + std::set model_tags; // model tags (informational, not used for routing) // NOLINT std::string hf_token = ""; // HF token // NOLINT std::string prompt = ""; // NOLINT std::string system_prompt = ""; // NOLINT @@ -804,6 +805,23 @@ void common_batch_add( const std::vector & seq_ids, bool logits); +// decodes a single batch of tokens for a prompt and manages session tokens +// +// Note: We save state before the last token so that we can replay it to ensure +// compatibility with all memory types. Recurrent/hybrid models cannot remove +// tokens from memory, so this approach works across all model architectures. +bool common_prompt_batch_decode( + struct llama_context * ctx, + const std::vector & embd, + int & n_past, + int n_batch, + std::string_view state_path, + bool save_state); + +// replays the last token after loading state to regenerate logits +// used after loading session state to ensure the sampling context has valid logits +bool common_replay_last_token(struct llama_context * ctx, llama_token last_token, int32_t pos); + // // Vocab utils // diff --git a/common/jinja/runtime.cpp b/common/jinja/runtime.cpp index cc012c892f..5757c76b7a 100644 --- a/common/jinja/runtime.cpp +++ b/common/jinja/runtime.cpp @@ -85,7 +85,7 @@ value identifier::execute_impl(context & ctx) { auto builtins = global_builtins(); if (!it->is_undefined()) { if (ctx.is_get_stats) { - it->stats.used = true; + value_t::stats_t::mark_used(it); } JJ_DEBUG("Identifier '%s' found, type = %s", val.c_str(), it->type().c_str()); return it; @@ -277,7 +277,7 @@ value binary_expression::execute_impl(context & ctx) { static value try_builtin_func(context & ctx, const std::string & name, value & input, bool undef_on_missing = false) { JJ_DEBUG("Trying built-in function '%s' for type %s", name.c_str(), input->type().c_str()); if (ctx.is_get_stats) { - input->stats.used = true; + value_t::stats_t::mark_used(input); input->stats.ops.insert(name); } auto builtins = input->get_builtins(); @@ -448,7 +448,7 @@ value for_statement::execute_impl(context & ctx) { // mark the variable being iterated as used for stats if (ctx.is_get_stats) { - iterable_val->stats.used = true; + value_t::stats_t::mark_used(iterable_val); iterable_val->stats.ops.insert("array_access"); } @@ -470,7 +470,7 @@ value for_statement::execute_impl(context & ctx) { items.push_back(std::move(tuple)); } if (ctx.is_get_stats) { - iterable_val->stats.used = true; + value_t::stats_t::mark_used(iterable_val); iterable_val->stats.ops.insert("object_access"); } } else { @@ -480,7 +480,7 @@ value for_statement::execute_impl(context & ctx) { items.push_back(item); } if (ctx.is_get_stats) { - iterable_val->stats.used = true; + value_t::stats_t::mark_used(iterable_val); iterable_val->stats.ops.insert("array_access"); } } @@ -721,6 +721,8 @@ value member_expression::execute_impl(context & ctx) { int64_t arr_size = 0; if (is_val(object)) { arr_size = object->as_array().size(); + } else if (is_val(object)) { + arr_size = object->as_string().length(); } if (is_stmt(this->property)) { @@ -817,8 +819,9 @@ value member_expression::execute_impl(context & ctx) { } if (ctx.is_get_stats && val && object && property) { - val->stats.used = true; - object->stats.used = true; + value_t::stats_t::mark_used(val); + value_t::stats_t::mark_used(object); + value_t::stats_t::mark_used(property); if (is_val(property)) { object->stats.ops.insert("array_access"); } else if (is_val(property)) { diff --git a/common/jinja/value.cpp b/common/jinja/value.cpp index 9987836d18..749113124b 100644 --- a/common/jinja/value.cpp +++ b/common/jinja/value.cpp @@ -161,6 +161,11 @@ static value tojson(const func_args & args) { value val_separators = args.get_kwarg_or_pos("separators", 3); value val_sort = args.get_kwarg_or_pos("sort_keys", 4); int indent = -1; + if (args.ctx.is_get_stats) { + // mark as used (recursively) for stats + auto val_input = args.get_pos(0); + value_t::stats_t::mark_used(const_cast(val_input), true); + } if (is_val(val_indent)) { indent = static_cast(val_indent->as_int()); } @@ -891,6 +896,11 @@ const func_builtins & value_array_t::get_builtins() const { }}, {"string", [](const func_args & args) -> value { args.ensure_vals(); + if (args.ctx.is_get_stats) { + // mark as used (recursively) for stats + auto val_input = args.get_pos(0); + value_t::stats_t::mark_used(const_cast(val_input), true); + } return mk_val(args.get_pos(0)->as_string()); }}, {"tojson", tojson}, @@ -1046,6 +1056,11 @@ const func_builtins & value_object_t::get_builtins() const { {"tojson", tojson}, {"string", [](const func_args & args) -> value { args.ensure_vals(); + if (args.ctx.is_get_stats) { + // mark as used (recursively) for stats + auto val_input = args.get_pos(0); + value_t::stats_t::mark_used(const_cast(val_input), true); + } return mk_val(args.get_pos(0)->as_string()); }}, {"length", [](const func_args & args) -> value { @@ -1358,4 +1373,21 @@ std::string value_to_string_repr(const value & val) { } } +// stats utility +void value_t::stats_t::mark_used(value & val, bool deep) { + val->stats.used = true; + if (deep) { + if (is_val(val)) { + for (auto & item : val->val_arr) { + mark_used(item, deep); + } + } else if (is_val(val)) { + for (auto & pair : val->val_obj) { + mark_used(pair.first, deep); + mark_used(pair.second, deep); + } + } + } +} + } // namespace jinja diff --git a/common/jinja/value.h b/common/jinja/value.h index 1c04760a08..07e447ff69 100644 --- a/common/jinja/value.h +++ b/common/jinja/value.h @@ -118,6 +118,8 @@ struct value_t { bool used = false; // ops can be builtin calls or operators: "array_access", "object_access" std::set ops; + // utility to recursively mark value and its children as used + static void mark_used(value & val, bool deep = false); } stats; value_t() = default; diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 31acd5bb48..0954417398 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -116,7 +116,8 @@ class ModelBase: split_max_tensors: int = 0, split_max_size: int = 0, dry_run: bool = False, small_first_shard: bool = False, hparams: dict[str, Any] | None = None, remote_hf_model_id: str | None = None, disable_mistral_community_chat_template: bool = False, - sentence_transformers_dense_modules: bool = False): + sentence_transformers_dense_modules: bool = False, + fuse_gate_up_exps: bool = False): if type(self) is ModelBase or \ type(self) is TextModel or \ type(self) is MmprojModel: @@ -135,6 +136,9 @@ class ModelBase: self.dry_run = dry_run self.remote_hf_model_id = remote_hf_model_id self.sentence_transformers_dense_modules = sentence_transformers_dense_modules + self.fuse_gate_up_exps = fuse_gate_up_exps + self._gate_exp_buffer: dict[int, Tensor] = {} + self._up_exp_buffer: dict[int, Tensor] = {} self.hparams = ModelBase.load_hparams(self.dir_model, self.is_mistral_format) if hparams is None else hparams self.model_tensors = self.index_tensors(remote_hf_model_id=remote_hf_model_id) self.metadata_override = metadata_override @@ -512,8 +516,31 @@ class ModelBase: raise NotImplementedError("set_gguf_parameters() must be implemented in subclasses") def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: - del bid # unused - return [(self.map_tensor_name(name), data_torch)] + new_name = self.map_tensor_name(name) + + # Handle gate/up expert tensor fusion if enabled + if self.fuse_gate_up_exps and bid is not None: + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid): + self._gate_exp_buffer[bid] = data_torch + elif self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid): + self._up_exp_buffer[bid] = data_torch + + # Check if both gate and up are buffered for this layer + if bid in self._gate_exp_buffer and bid in self._up_exp_buffer: + gate_data = self._gate_exp_buffer.pop(bid) + up_data = self._up_exp_buffer.pop(bid) + # gate/up shape: (n_expert, n_ff, n_embd), concatenate to (n_expert, n_ff*2, n_embd) + fused_data = torch.cat([gate_data, up_data], dim=1) + fused_name = self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_UP_EXP, bid) + logger.info(f"Fused gate_exps and up_exps for layer {bid}") + return [(fused_name, fused_data)] + + # If we buffered a gate/up tensor, wait for the other + if self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_GATE_EXP, bid) or \ + self.match_model_tensor_name(new_name, gguf.MODEL_TENSOR.FFN_UP_EXP, bid): + return [] + + return [(new_name, data_torch)] def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool: del name, new_name, bid, n_dims # unused @@ -1148,6 +1175,9 @@ class TextModel(ModelBase): if chkhsh == "27949a2493fc4a9f53f5b9b029c82689cfbe5d3a1929bb25e043089e28466de6": # ref: https://huggingface.co/jinaai/jina-embeddings-v2-base-de res = "jina-v2-de" + if chkhsh == "a023e9fdc5a11f034d3ef515b92350e56fb2af1f66c6b6811a4444ea9bf8763d": + # ref: https://huggingface.co/jinaai/jina-embeddings-v5-text-nano + res = "jina-v5-nano" if chkhsh == "c136ed14d01c2745d4f60a9596ae66800e2b61fa45643e72436041855ad4089d": # ref: https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct res = "smaug-bpe" @@ -1274,6 +1304,9 @@ class TextModel(ModelBase): if chkhsh == "b4b8ca1f9769494fbd956ebc4c249de6131fb277a4a3345a7a92c7dd7a55808d": # ref: https://huggingface.co/jdopensource/JoyAI-LLM-Flash res = "joyai-llm" + if chkhsh == "e4d54df1ebc1f2b91acd986c5b51aa50837d5faf7c7398e73c1f9e9ee5d19869": + # ref: https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601 + res = "kanana2" if res is None: logger.warning("\n") @@ -6122,6 +6155,32 @@ class NeoBert(BertModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register("EuroBertModel", "JinaEmbeddingsV5Model") +class EuroBertModel(TextModel): + model_arch = gguf.MODEL_ARCH.EUROBERT + + def set_vocab(self): + self.gguf_writer.add_add_bos_token(False) + self._set_vocab_gpt2() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + # EuroBert is bidirectional (encoder) + self.gguf_writer.add_causal_attention(False) + + self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) + + self._try_set_pooling_type() + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Strip "model." prefix from tensor names + if name.startswith("model."): + name = name[6:] + + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("XLMRobertaModel", "XLMRobertaForSequenceClassification") class XLMRobertaModel(BertModel): model_arch = gguf.MODEL_ARCH.BERT @@ -11910,6 +11969,11 @@ def parse_args() -> argparse.Namespace: "Default these modules are not included.") ) + parser.add_argument( + "--fuse-gate-up-exps", action="store_true", + help="Fuse gate_exps and up_exps tensors into a single gate_up_exps tensor for MoE models.", + ) + args = parser.parse_args() if not args.print_supported_models and args.model is None: parser.error("the following arguments are required: model") @@ -12047,7 +12111,8 @@ def main() -> None: split_max_size=split_str_to_n_bytes(args.split_max_size), dry_run=args.dry_run, small_first_shard=args.no_tensor_first_split, remote_hf_model_id=hf_repo_id, disable_mistral_community_chat_template=disable_mistral_community_chat_template, - sentence_transformers_dense_modules=args.sentence_transformers_dense_modules + sentence_transformers_dense_modules=args.sentence_transformers_dense_modules, + fuse_gate_up_exps=args.fuse_gate_up_exps ) if args.vocab_only: diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 8f7443d1b5..b31ddcca77 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -107,6 +107,7 @@ models = [ {"name": "jina-v2-en", "tokt": TOKENIZER_TYPE.WPM, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-en", }, # WPM! {"name": "jina-v2-es", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-es", }, {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-de", }, + {"name": "jina-v5-nano", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v5-text-nano", }, {"name": "smaug-bpe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/abacusai/Smaug-Llama-3-70B-Instruct", }, {"name": "poro-chat", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LumiOpen/Poro-34B-chat", }, {"name": "jina-v2-code", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jinaai/jina-embeddings-v2-base-code", }, @@ -152,6 +153,7 @@ models = [ {"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", }, {"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }, {"name": "joyai-llm", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/jdopensource/JoyAI-LLM-Flash", }, + {"name": "kanana2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/kakaocorp/kanana-2-30b-a3b-instruct-2601", }, ] # some models are known to be broken upstream, so we will skip them as exceptions diff --git a/docs/backend/VirtGPU.md b/docs/backend/VirtGPU.md index c81468da13..72f376dea7 100644 --- a/docs/backend/VirtGPU.md +++ b/docs/backend/VirtGPU.md @@ -152,7 +152,9 @@ Commands and data are serialized using a custom binary protocol with: - **VM-specific**: Only works in virtual machines with virtio-gpu support - **Host dependency**: Requires properly configured host-side backend - **Latency**: Small overhead from VM escaping for each operation - +- **Shared-memory size**: with the `libkrun` hypervisor, the RAM + VRAM + addressable memory is limited to 64 GB. So the maximum GPU memory + will be `64GB - RAM`, regardless of the hardware VRAM size. * This work is pending upstream changes in the VirglRenderer project. diff --git a/docs/backend/ZenDNN.md b/docs/backend/ZenDNN.md index b57fd97b69..3b1f8242ff 100644 --- a/docs/backend/ZenDNN.md +++ b/docs/backend/ZenDNN.md @@ -22,7 +22,7 @@ **Llama.cpp + ZenDNN** -The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL BLIS, LibXSMM, OneDNN). +The llama.cpp ZenDNN backend leverages AMD's optimized matrix multiplication primitives to accelerate inference on AMD CPUs. It utilizes ZenDNN's **LowOHA (Low Overhead Hardware Accelerated)** MatMul operator for efficient GEMM operations with minimal execution overhead, built-in weight caching, and direct access to backend libraries (AOCL DLP, LibXSMM, OneDNN). For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendnn.html @@ -32,7 +32,7 @@ For more information about ZenDNN, visit: https://www.amd.com/en/developer/zendn |:-------:|:-------:|:----------------------------------------------:| | Linux | Support | Ubuntu 20.04, 22.04, 24.04 | -For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/zendnnl/README.md#15-supported-os). +For the latest list of supported operating systems, see the [ZenDNN Supported OS](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md#15-supported-os). ## Hardware @@ -44,9 +44,9 @@ ZenDNN is optimized for AMD EPYC™ processors and AMD Ryzen™ processors based | CPU Family | Status | Notes | |:-----------------------------:|:-------:|:----------------------------------:| -| AMD EPYC™ 9005 Series (Turin)| Support | 5th Gen - Zen 5 architecture | -| AMD EPYC™ 9004 Series (Genoa)| Support | 4th Gen - Zen 4 architecture | -| AMD EPYC™ 7003 Series (Milan)| Support | 3rd Gen - Zen 3 architecture | +| AMD EPYC™ 9005 Series (Turin) | Support | 5th Gen - Zen 5 architecture | +| AMD EPYC™ 9004 Series (Genoa) | Support | 4th Gen - Zen 4 architecture | +| AMD EPYC™ 7003 Series (Milan) | Support | 3rd Gen - Zen 3 architecture | | AMD Ryzen™ AI MAX (Strix Halo)| Support | High-performance mobile processors | *Notes:* @@ -61,7 +61,7 @@ The ZenDNN backend currently accelerates **matrix multiplication (MUL_MAT)** ope | Operation | Status | Notes | |:-------------|:-------:|:----------------------------------------------:| -| MUL_MAT | ✓ | Accelerated via ZenDNN LowOHA MatMul | +| MUL_MAT | Support | Accelerated via ZenDNN LowOHA MatMul | *Note:* Since only MUL_MAT is accelerated, models will benefit most from ZenDNN when matrix multiplications dominate the computational workload (which is typical for transformer-based LLMs). @@ -104,7 +104,6 @@ If you want to build ZenDNN yourself or use a specific version: # Clone ZenDNN repository git clone https://github.com/amd/ZenDNN.git cd ZenDNN -git checkout zendnnl # Build and install (requires CMake >= 3.25) mkdir build && cd build @@ -114,7 +113,7 @@ cmake --build . --target all Default installation path: `ZenDNN/build/install` -**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/zendnnl/README.md). +**For detailed build instructions**, refer to the [ZenDNN README](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/README.md). **Step 2: Build llama.cpp with custom ZenDNN path** @@ -146,8 +145,7 @@ Run llama.cpp server with ZenDNN acceleration: ```sh # Set optimal configuration -export OMP_NUM_THREADS=64 # Adjust to your CPU core count -export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance +export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo for best performance # Start server ./build/bin/llama-server \ @@ -160,62 +158,26 @@ export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS for best performance Access the server at `http://localhost:8080`. **Performance tips**: -- Set `OMP_NUM_THREADS` to match your physical core count -- Use `ZENDNNL_MATMUL_ALGO=2` for optimal performance +- Use `ZENDNNL_MATMUL_ALGO=1` for optimal performance - For NUMA systems: `numactl --cpunodebind=0 --membind=0 ./build/bin/llama-server ...` ## Environment Variable -### Build Time +For environment variables related to ZenDNN, refer to the [ZenDNN Environment Variables Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md). -| Name | Value | Function | -|--------------------|---------------------------------------|---------------------------------------------| -| GGML_ZENDNN | ON/OFF | Enable ZenDNN backend support | -| ZENDNN_ROOT | Path to ZenDNN installation | Set ZenDNN installation directory | -| GGML_OPENMP | ON/OFF (recommended: ON) | Enable OpenMP for multi-threading | +### Performance Optimization -### Runtime - -| Name | Value | Function | -|-------------------------|--------------------------|-------------------------------------------------------------------| -| OMP_NUM_THREADS | Number (e.g., 64) | Set number of OpenMP threads (recommended: physical core count) | -| ZENDNNL_MATMUL_ALGO | 0-5 | Select MatMul backend algorithm (see Performance Optimization) | -| ZENDNNL_PROFILE_LOG_LEVEL | 0-4 | Profiling log level (0=disabled, 4=verbose) | -| ZENDNNL_ENABLE_PROFILER | 0 or 1 | Enable detailed profiling (1=enabled) | -| ZENDNNL_API_LOG_LEVEL | 0-4 | API log level (0=disabled, 4=verbose) | - -**Example**: +ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL DLP** algorithm: ```sh -export OMP_NUM_THREADS=64 -export ZENDNNL_MATMUL_ALGO=2 # Use Blocked AOCL BLIS for best performance -./build/bin/llama-cli -m models/llama-2-7b.Q4_0.gguf -p "Test" -n 100 +export ZENDNNL_MATMUL_ALGO=1 # Blocked AOCL DLP algo (recommended) ``` -## Performance Optimization - -### MatMul Algorithm Selection - -ZenDNN's LowOHA MatMul supports multiple backend algorithms. For **best performance**, use the **Blocked AOCL BLIS** algorithm: - -```sh -export ZENDNNL_MATMUL_ALGO=2 # Blocked AOCL BLIS (recommended) -``` - -**Available algorithms**: - -| Value | Algorithm | Description | -|:-----:|:-----------------------|:----------------------------------------------| -| 0 | Dynamic Dispatch | Automatic backend selection (default) | -| 1 | AOCL BLIS | AOCL BLIS backend | -| 2 | AOCL BLIS Blocked | **Blocked AOCL BLIS (recommended)** | -| 3 | OneDNN | OneDNN backend | -| 4 | OneDNN Blocked | Blocked OneDNN | -| 5 | LibXSMM | LibXSMM backend | +For more details on available algorithms, see the [ZenDNN MatMul Algorithm Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/runtime_env.md#algorithm-details). ### Profiling and Debugging -For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/zendnnl/docs/logging.md). +For detailed profiling and logging options, refer to the [ZenDNN Logging Documentation](https://github.com/amd/ZenDNN/blob/a18adf8c605fb5f5e52cefd7eda08a7b18febbaf/docs/logging.md). ## Known Issues @@ -245,10 +207,9 @@ A: Currently, ZenDNN primarily supports FP32 and BF16 data types. Quantized mode A: Ensure: 1. You're using an AMD EPYC or Ryzen processor (Zen 2 or newer) -2. `OMP_NUM_THREADS` is set appropriately (physical core count) -3. `ZENDNNL_MATMUL_ALGO=2` is set for best performance (Blocked AOCL BLIS) -4. You're using a sufficiently large model (small models may not benefit as much) -5. Enable profiling to verify ZenDNN MatMul is being called +2. `ZENDNNL_MATMUL_ALGO=1` is set for best performance (Blocked AOCL DLP) +3. You're using a sufficiently large model (small models may not benefit as much) +4. Enable profiling to verify ZenDNN MatMul is being called ### **GitHub Contribution**: Please add the **[ZenDNN]** prefix/tag in issues/PRs titles to help the ZenDNN-team check/address them without delay. diff --git a/examples/model-conversion/Makefile b/examples/model-conversion/Makefile index 342de63bd0..9356aaf854 100644 --- a/examples/model-conversion/Makefile +++ b/examples/model-conversion/Makefile @@ -77,7 +77,10 @@ causal-verify-embeddings: causal-run-original-embeddings causal-run-converted-em @./scripts/causal/compare-embeddings-logits.sh causal-inspect-original-model: - @./scripts/utils/inspect-org-model.py + @./scripts/utils/inspect-org-model.py --list-all -s + +causal-list-original-model-tensors: + @./scripts/utils/inspect-org-model.py --list-all-short -s causal-inspect-converted-model: @./scripts/utils/inspect-converted-model.sh @@ -153,7 +156,7 @@ embedding-verify-logits-st: embedding-run-original-model-st embedding-run-conver embedding-inspect-original-model: $(call validate_embedding_model_path,embedding-inspect-original-model) - @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} + @EMBEDDING_MODEL_PATH="$(EMBEDDING_MODEL_PATH)" ./scripts/utils/inspect-org-model.py -m ${EMBEDDING_MODEL_PATH} --list-all -s embedding-inspect-converted-model: @CONVERTED_EMBEDDING_MODEL="$(CONVERTED_EMBEDDING_MODEL)" ./scripts/utils/inspect-converted-model.sh ${CONVERTED_EMBEDDING_MODEL} diff --git a/examples/model-conversion/scripts/utils/inspect-org-model.py b/examples/model-conversion/scripts/utils/inspect-org-model.py index bc6f45a5fb..5c3674af71 100755 --- a/examples/model-conversion/scripts/utils/inspect-org-model.py +++ b/examples/model-conversion/scripts/utils/inspect-org-model.py @@ -1,67 +1,290 @@ #!/usr/bin/env python3 import argparse -import os import json +import os +import re +import struct +import sys +from pathlib import Path +from typing import Optional from safetensors import safe_open -from collections import defaultdict -parser = argparse.ArgumentParser(description='Process model with specified path') -parser.add_argument('--model-path', '-m', help='Path to the model') -args = parser.parse_args() -model_path = os.environ.get('MODEL_PATH', args.model_path) -if model_path is None: - parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") +MODEL_SAFETENSORS_FILE = "model.safetensors" +MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" -# Check if there's an index file (multi-file model) -index_path = os.path.join(model_path, "model.safetensors.index.json") -single_file_path = os.path.join(model_path, "model.safetensors") +DTYPE_SIZES = { + "F64": 8, "I64": 8, "U64": 8, + "F32": 4, "I32": 4, "U32": 4, + "F16": 2, "BF16": 2, "I16": 2, "U16": 2, + "I8": 1, "U8": 1, "BOOL": 1, + "F8_E4M3": 1, "F8_E5M2": 1, +} -if os.path.exists(index_path): - # Multi-file model - print("Multi-file model detected") +SIZE_UNITS = ['B', 'KB', 'MB', 'GB', 'TB'] - with open(index_path, 'r') as f: - index_data = json.load(f) - # Get the weight map (tensor_name -> file_name) - weight_map = index_data.get("weight_map", {}) +def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: + index_file = model_path / MODEL_SAFETENSORS_INDEX - # Group tensors by file for efficient processing - file_tensors = defaultdict(list) - for tensor_name, file_name in weight_map.items(): - file_tensors[file_name].append(tensor_name) + if index_file.exists(): + with open(index_file, 'r') as f: + index = json.load(f) + return index.get("weight_map", {}) - print("Tensors in model:") + return None - # Process each shard file - for file_name, tensor_names in file_tensors.items(): - file_path = os.path.join(model_path, file_name) - print(f"\n--- From {file_name} ---") - with safe_open(file_path, framework="pt") as f: - for tensor_name in sorted(tensor_names): - tensor = f.get_tensor(tensor_name) - print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}") +def get_all_tensor_names(model_path: Path) -> list[str]: + weight_map = get_weight_map(model_path) -elif os.path.exists(single_file_path): - # Single file model (original behavior) - print("Single-file model detected") + if weight_map is not None: + return list(weight_map.keys()) - with safe_open(single_file_path, framework="pt") as f: - keys = f.keys() - print("Tensors in model:") - for key in sorted(keys): - tensor = f.get_tensor(key) - print(f"- {key} : shape = {tensor.shape}, dtype = {tensor.dtype}") + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + try: + with safe_open(single_file, framework="pt", device="cpu") as f: + return list(f.keys()) + except Exception as e: + print(f"Error reading {single_file}: {e}") + sys.exit(1) -else: - print(f"Error: Neither 'model.safetensors.index.json' nor 'model.safetensors' found in {model_path}") - print("Available files:") - if os.path.exists(model_path): - for item in sorted(os.listdir(model_path)): - print(f" {item}") + print(f"Error: No safetensors files found in {model_path}") + sys.exit(1) + + +def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + return weight_map.get(tensor_name) + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + return single_file.name + + return None + + +def read_safetensors_header(file_path: Path) -> dict: + with open(file_path, 'rb') as f: + header_size = struct.unpack(' int: + offsets = tensor_meta.get("data_offsets") + if offsets and len(offsets) == 2: + return offsets[1] - offsets[0] + n_elements = 1 + for d in tensor_meta.get("shape", []): + n_elements *= d + return n_elements * DTYPE_SIZES.get(tensor_meta.get("dtype", "F32"), 4) + + +def format_size(size_bytes: int) -> str: + val = float(size_bytes) + for unit in SIZE_UNITS[:-1]: + if val < 1024.0: + return f"{val:.2f} {unit}" + val /= 1024.0 + return f"{val:.2f} {SIZE_UNITS[-1]}" + + +def get_all_tensor_metadata(model_path: Path) -> dict[str, dict]: + weight_map = get_weight_map(model_path) + + if weight_map is not None: + file_to_tensors: dict[str, list[str]] = {} + for tensor_name, file_name in weight_map.items(): + file_to_tensors.setdefault(file_name, []).append(tensor_name) + + all_metadata: dict[str, dict] = {} + for file_name, tensor_names in file_to_tensors.items(): + try: + header = read_safetensors_header(model_path / file_name) + for tensor_name in tensor_names: + if tensor_name in header: + all_metadata[tensor_name] = header[tensor_name] + except Exception as e: + print(f"Warning: Could not read header from {file_name}: {e}", file=sys.stderr) + return all_metadata + + single_file = model_path / MODEL_SAFETENSORS_FILE + if single_file.exists(): + try: + header = read_safetensors_header(single_file) + return {k: v for k, v in header.items() if k != "__metadata__"} + except Exception as e: + print(f"Error reading {single_file}: {e}") + sys.exit(1) + + print(f"Error: No safetensors files found in {model_path}") + sys.exit(1) + + +def normalize_tensor_name(tensor_name: str) -> str: + normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) + normalized = re.sub(r'\.\d+$', '.#', normalized) + return normalized + + +def list_all_tensors( + model_path: Path, + short: bool = False, + show_sizes: bool = False, +): + tensor_names = get_all_tensor_names(model_path) + + metadata: Optional[dict[str, dict]] = None + if show_sizes: + metadata = get_all_tensor_metadata(model_path) + + total_bytes = 0 + + if short: + seen: dict[str, str] = {} + for tensor_name in sorted(tensor_names): + normalized = normalize_tensor_name(tensor_name) + if normalized not in seen: + seen[normalized] = tensor_name + display_pairs = list(sorted(seen.items())) + name_width = max((len(n) for n, _ in display_pairs), default=0) + for normalized, first_name in display_pairs: + if metadata and first_name in metadata: + m = metadata[first_name] + size = get_tensor_size_bytes(m) + total_bytes += size + print(f"{normalized:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}") + else: + print(normalized) else: - print(f" Directory {model_path} does not exist") - exit(1) + name_width = max((len(n) for n in tensor_names), default=0) + for tensor_name in sorted(tensor_names): + if metadata and tensor_name in metadata: + m = metadata[tensor_name] + size = get_tensor_size_bytes(m) + total_bytes += size + print(f"{tensor_name:{name_width}} {m.get('dtype', '?'):6s} {str(m.get('shape', '')):30s} {format_size(size)}") + else: + print(tensor_name) + + if show_sizes: + print(f"\nTotal: {format_size(total_bytes)}") + + +def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None): + tensor_file = find_tensor_file(model_path, tensor_name) + + if tensor_file is None: + print(f"Error: Could not find tensor '{tensor_name}' in model index") + print(f"Model path: {model_path}") + sys.exit(1) + + file_path = model_path / tensor_file + + try: + header = read_safetensors_header(file_path) + tensor_meta = header.get(tensor_name, {}) + dtype_str = tensor_meta.get("dtype") + + with safe_open(file_path, framework="pt", device="cpu") as f: + if tensor_name in f.keys(): + tensor_slice = f.get_slice(tensor_name) + shape = tensor_slice.get_shape() + print(f"Tensor: {tensor_name}") + print(f"File: {tensor_file}") + print(f"Shape: {shape}") + if dtype_str: + print(f"Dtype: {dtype_str}") + if tensor_meta: + print(f"Size: {format_size(get_tensor_size_bytes(tensor_meta))}") + if num_values is not None: + tensor = f.get_tensor(tensor_name) + if not dtype_str: + print(f"Dtype: {tensor.dtype}") + flat = tensor.flatten() + n = min(num_values, flat.numel()) + print(f"Values: {flat[:n].tolist()}") + else: + print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") + sys.exit(1) + + except FileNotFoundError: + print(f"Error: The file '{file_path}' was not found.") + sys.exit(1) + except Exception as e: + print(f"An error occurred: {e}") + sys.exit(1) + + +def main(): + parser = argparse.ArgumentParser( + description="Print tensor information from a safetensors model" + ) + parser.add_argument( + "tensor_name", + nargs="?", + help="Name of the tensor to inspect" + ) + parser.add_argument( + "-m", "--model-path", + type=Path, + help="Path to the model directory (default: MODEL_PATH environment variable)" + ) + parser.add_argument( + "-l", "--list-all-short", + action="store_true", + help="List unique tensor patterns (layer numbers replaced with #)" + ) + parser.add_argument( + "-la", "--list-all", + action="store_true", + help="List all tensor names with actual layer numbers" + ) + parser.add_argument( + "-n", "--num-values", + nargs="?", + const=10, + default=None, + type=int, + metavar="N", + help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)" + ) + parser.add_argument( + "-s", "--sizes", + action="store_true", + help="Show dtype, shape, and size for each tensor when listing" + ) + + args = parser.parse_args() + + model_path = args.model_path + if model_path is None: + model_path_str = os.environ.get("MODEL_PATH") + if model_path_str is None: + print("Error: --model-path not provided and MODEL_PATH environment variable not set") + sys.exit(1) + model_path = Path(model_path_str) + + if not model_path.exists(): + print(f"Error: Model path does not exist: {model_path}") + sys.exit(1) + + if not model_path.is_dir(): + print(f"Error: Model path is not a directory: {model_path}") + sys.exit(1) + + if args.list_all_short or args.list_all: + list_all_tensors(model_path, short=args.list_all_short, show_sizes=args.sizes) + else: + if args.tensor_name is None: + print("Error: tensor_name is required when not using --list-all-short or --list-all") + sys.exit(1) + print_tensor_info(model_path, args.tensor_name, args.num_values) + + +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/utils/tensor-info.py b/examples/model-conversion/scripts/utils/tensor-info.py deleted file mode 100755 index 1bb9e0564c..0000000000 --- a/examples/model-conversion/scripts/utils/tensor-info.py +++ /dev/null @@ -1,174 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import json -import os -import re -import sys -from pathlib import Path -from typing import Optional -from safetensors import safe_open - - -MODEL_SAFETENSORS_FILE = "model.safetensors" -MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json" - - -def get_weight_map(model_path: Path) -> Optional[dict[str, str]]: - index_file = model_path / MODEL_SAFETENSORS_INDEX - - if index_file.exists(): - with open(index_file, 'r') as f: - index = json.load(f) - return index.get("weight_map", {}) - - return None - - -def get_all_tensor_names(model_path: Path) -> list[str]: - weight_map = get_weight_map(model_path) - - if weight_map is not None: - return list(weight_map.keys()) - - single_file = model_path / MODEL_SAFETENSORS_FILE - if single_file.exists(): - try: - with safe_open(single_file, framework="pt", device="cpu") as f: - return list(f.keys()) - except Exception as e: - print(f"Error reading {single_file}: {e}") - sys.exit(1) - - print(f"Error: No safetensors files found in {model_path}") - sys.exit(1) - - -def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]: - weight_map = get_weight_map(model_path) - - if weight_map is not None: - return weight_map.get(tensor_name) - - single_file = model_path / MODEL_SAFETENSORS_FILE - if single_file.exists(): - return single_file.name - - return None - - -def normalize_tensor_name(tensor_name: str) -> str: - normalized = re.sub(r'\.\d+\.', '.#.', tensor_name) - normalized = re.sub(r'\.\d+$', '.#', normalized) - return normalized - - -def list_all_tensors(model_path: Path, unique: bool = False): - tensor_names = get_all_tensor_names(model_path) - - if unique: - seen = set() - for tensor_name in sorted(tensor_names): - normalized = normalize_tensor_name(tensor_name) - if normalized not in seen: - seen.add(normalized) - print(normalized) - else: - for tensor_name in sorted(tensor_names): - print(tensor_name) - - -def print_tensor_info(model_path: Path, tensor_name: str, num_values: Optional[int] = None): - tensor_file = find_tensor_file(model_path, tensor_name) - - if tensor_file is None: - print(f"Error: Could not find tensor '{tensor_name}' in model index") - print(f"Model path: {model_path}") - sys.exit(1) - - file_path = model_path / tensor_file - - try: - with safe_open(file_path, framework="pt", device="cpu") as f: - if tensor_name in f.keys(): - tensor_slice = f.get_slice(tensor_name) - shape = tensor_slice.get_shape() - print(f"Tensor: {tensor_name}") - print(f"File: {tensor_file}") - print(f"Shape: {shape}") - if num_values is not None: - tensor = f.get_tensor(tensor_name) - print(f"Dtype: {tensor.dtype}") - flat = tensor.flatten() - n = min(num_values, flat.numel()) - print(f"Values: {flat[:n].tolist()}") - else: - print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}") - sys.exit(1) - - except FileNotFoundError: - print(f"Error: The file '{file_path}' was not found.") - sys.exit(1) - except Exception as e: - print(f"An error occurred: {e}") - sys.exit(1) - - -def main(): - parser = argparse.ArgumentParser( - description="Print tensor information from a safetensors model" - ) - parser.add_argument( - "tensor_name", - nargs="?", # optional (if --list is used for example) - help="Name of the tensor to inspect" - ) - parser.add_argument( - "-m", "--model-path", - type=Path, - help="Path to the model directory (default: MODEL_PATH environment variable)" - ) - parser.add_argument( - "-l", "--list", - action="store_true", - help="List unique tensor patterns in the model (layer numbers replaced with #)" - ) - parser.add_argument( - "-n", "--num-values", - nargs="?", - const=10, - default=None, - type=int, - metavar="N", - help="Print the first N values of the tensor flattened (default: 10 if flag is given without a number)" - ) - - args = parser.parse_args() - - model_path = args.model_path - if model_path is None: - model_path_str = os.environ.get("MODEL_PATH") - if model_path_str is None: - print("Error: --model-path not provided and MODEL_PATH environment variable not set") - sys.exit(1) - model_path = Path(model_path_str) - - if not model_path.exists(): - print(f"Error: Model path does not exist: {model_path}") - sys.exit(1) - - if not model_path.is_dir(): - print(f"Error: Model path is not a directory: {model_path}") - sys.exit(1) - - if args.list: - list_all_tensors(model_path, unique=True) - else: - if args.tensor_name is None: - print("Error: tensor_name is required when not using --list") - sys.exit(1) - print_tensor_info(model_path, args.tensor_name, args.num_values) - - -if __name__ == "__main__": - main() diff --git a/examples/save-load-state/save-load-state.cpp b/examples/save-load-state/save-load-state.cpp index 39d4464663..5e35dcd603 100644 --- a/examples/save-load-state/save-load-state.cpp +++ b/examples/save-load-state/save-load-state.cpp @@ -5,12 +5,15 @@ #include #include + int main(int argc, char ** argv) { common_params params; params.prompt = "The quick brown fox"; params.sampling.seed = 1234; + const std::string_view state_file = "dump_state.bin"; + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_COMMON)) { return 1; } @@ -53,35 +56,16 @@ int main(int argc, char ** argv) { // tokenize prompt auto tokens = common_tokenize(ctx, params.prompt, true); - // prepare the batch - llama_batch batch = llama_batch_init(tokens.size(), 0, 1); - for (size_t i = 0; i < tokens.size(); i++) { - common_batch_add(batch, tokens[i], i, {0}, false); + const bool save_state = true; + if (!common_prompt_batch_decode(ctx, tokens, n_past, params.n_batch, state_file, save_state)) { + return 1; } - batch.logits[batch.n_tokens - 1] = true; // generate next token - - // evaluate prompt - llama_decode(ctx, batch); - n_past += batch.n_tokens; - - // save state (rng, logits, embedding and kv_cache) to file - { - std::vector state_mem(llama_state_get_size(ctx)); - const size_t written = llama_state_get_data(ctx, state_mem.data(), state_mem.size()); - - FILE *fp_write = fopen("dump_state.bin", "wb"); - fwrite(state_mem.data(), 1, written, fp_write); - fclose(fp_write); - - fprintf(stderr, "%s : serialized state into %zd out of a maximum of %zd bytes\n", __func__, written, state_mem.size()); - } - - // save state (last tokens) - const auto n_past_saved = n_past; // first run printf("\nfirst run: %s", params.prompt.c_str()); + llama_batch batch = llama_batch_init(1, 0, 1); + for (auto i = 0; i < params.n_predict; i++) { auto next_token = llama_sampler_sample(smpl, ctx, -1); auto next_token_str = common_token_to_piece(ctx, next_token); @@ -111,27 +95,23 @@ int main(int argc, char ** argv) { printf("\nsecond run: %s", params.prompt.c_str()); - // load state (rng, logits, embedding and kv_cache) from file - { - std::vector state_mem; + // load state from file + std::vector unused_sts(tokens.size()); // unused session tokens. + size_t n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx2, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx2, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!common_replay_last_token(ctx2, tokens.back(), n_past)) { + return 1; + } + ++n_past; // second run for (auto i = 0; i < params.n_predict; i++) { @@ -160,7 +140,9 @@ int main(int argc, char ** argv) { } // make new context - llama_context * ctx3 = llama_init_from_model(model, common_context_params_to_llama(params)); + auto params_ctx3 = common_context_params_to_llama(params); + params_ctx3.n_seq_max = 2; + llama_context * ctx3 = llama_init_from_model(model, params_ctx3); llama_sampler * smpl3 = llama_sampler_chain_init(sparams); @@ -169,26 +151,21 @@ int main(int argc, char ** argv) { printf("\nsingle seq run: %s", params.prompt.c_str()); // load state (rng, logits, embedding and kv_cache) from file - { - std::vector state_mem; + n_token_count_out = 0; - FILE * fp_read = fopen("dump_state.bin", "rb"); - fseek(fp_read, 0, SEEK_END); - state_mem.resize(ftell(fp_read)); - fseek(fp_read, 0, SEEK_SET); - const size_t read = fread(state_mem.data(), 1, state_mem.size(), fp_read); - fclose(fp_read); - - if (read != llama_state_set_data(ctx3, state_mem.data(), state_mem.size())) { - fprintf(stderr, "\n%s : failed to read state\n", __func__); - return 1; - } - - fprintf(stderr, "%s : deserialized state from %zd out of a maximum of %zd bytes\n", __func__, read, state_mem.size()); + if (!llama_state_load_file(ctx3, state_file.data(), unused_sts.data(), unused_sts.size(), &n_token_count_out)) { + fprintf(stderr, "\n%s : failed to load state\n", __func__); + return 1; } + fprintf(stderr, "%s : loaded state with %zu tokens\n", __func__, n_token_count_out); + // restore state (last tokens) - n_past = n_past_saved; + n_past = n_token_count_out; + if (!common_replay_last_token(ctx3, tokens.back(), n_past)) { + return 1; + } + ++n_past; // save seq 0 and load into seq 1 { diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 77af0e7fb6..fcc51f1f71 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -730,10 +730,6 @@ extern "C" { GGML_API size_t ggml_type_size(enum ggml_type type); // size in bytes for all elements in a block GGML_API size_t ggml_row_size (enum ggml_type type, int64_t ne); // size in bytes for all elements in a row - GGML_DEPRECATED( - GGML_API double ggml_type_sizef(enum ggml_type type), // ggml_type_size()/ggml_blck_size() as float - "use ggml_row_size() instead"); - GGML_API const char * ggml_type_name(enum ggml_type type); GGML_API const char * ggml_op_name (enum ggml_op op); GGML_API const char * ggml_op_symbol(enum ggml_op op); diff --git a/ggml/src/ggml-cpu/amx/amx.cpp b/ggml/src/ggml-cpu/amx/amx.cpp index 895a571375..9baf3e025e 100644 --- a/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ggml/src/ggml-cpu/amx/amx.cpp @@ -141,27 +141,50 @@ static size_t ggml_backend_amx_buffer_type_get_alignment(ggml_backend_buffer_typ namespace ggml::cpu::amx { class extra_buffer_type : ggml::cpu::extra_buffer_type { bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { - // handle only 2d gemm for now - auto is_contiguous_2d = [](const struct ggml_tensor * t) { - return ggml_is_contiguous(t) && t->ne[3] == 1 && t->ne[2] == 1; - }; - - if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous - is_contiguous_2d(op->src[1]) && // src1 must be contiguous - op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && - op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) - op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x - (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { - // src1 must be host buffer - if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { - return false; - } - // src1 must be float32 - if (op->src[1]->type == GGML_TYPE_F32) { - return true; - } + if (op->op != GGML_OP_MUL_MAT) { + return false; } - return false; + auto * src0 = op->src[0]; + auto * src1 = op->src[1]; + + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + return false; + } + if (!src0->buffer || src0->buffer->buft != ggml_backend_amx_buffer_type()) { + return false; + } + if (src1->buffer && !ggml_backend_buft_is_host(src1->buffer->buft)) { + return false; + } + if (op->ne[0] % (TILE_N * 2)) { + return false; + } + int alignment; + switch (src0->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q8_0: + alignment = TILE_K; + break; + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ4_XS: + alignment = 256; // QK_K + break; + case GGML_TYPE_F16: + alignment = 16; + break; + default: + return false; + } + if (src0->ne[0] % alignment) { + return false; + } + if (src1->type != GGML_TYPE_F32) { + return false; + } + return true; } ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { diff --git a/ggml/src/ggml-cpu/amx/mmq.cpp b/ggml/src/ggml-cpu/amx/mmq.cpp index 47c61b8816..b5aca76633 100644 --- a/ggml/src/ggml-cpu/amx/mmq.cpp +++ b/ggml/src/ggml-cpu/amx/mmq.cpp @@ -1,4 +1,3 @@ - #if defined(__GNUC__) #pragma GCC diagnostic ignored "-Wpedantic" #pragma GCC diagnostic ignored "-Wunused-local-typedefs" @@ -202,35 +201,27 @@ struct tile_config_t{ // advanced-matrix-extensions-intrinsics-functions.html // -#define TC_CONFIG_TILE(i, r, cb) tc.rows[i] = r; tc.colsb[i] = cb -void ggml_tile_config_init(void) { - static thread_local bool is_first_time = true; +inline void ggml_tile_config_init(void) { + static thread_local bool done = false; - if (!is_first_time) { + if (done) { return; } - static thread_local tile_config_t tc; - tile_config_t current_tc; - _tile_storeconfig(¤t_tc); + alignas(64) tile_config_t tc = {}; + tc.palette_id = 1; + tc.start_row = 0; + tc.rows[0] = 8; tc.colsb[0] = 64; + tc.rows[1] = 8; tc.colsb[1] = 64; + tc.rows[2] = 16; tc.colsb[2] = 32; + tc.rows[3] = 16; tc.colsb[3] = 32; + tc.rows[4] = 16; tc.colsb[4] = 64; + tc.rows[5] = 16; tc.colsb[5] = 64; + tc.rows[6] = 16; tc.colsb[6] = 64; + tc.rows[7] = 16; tc.colsb[7] = 64; - // load only when config changes - if (tc.palette_id == 0 || (memcmp(¤t_tc.colsb, &tc.colsb, sizeof(uint16_t) * 8) != 0 && - memcmp(¤t_tc.rows, &tc.rows, sizeof(uint8_t) * 8) != 0)) { - tc.palette_id = 1; - tc.start_row = 0; - TC_CONFIG_TILE(TMM0, 8, 64); - TC_CONFIG_TILE(TMM1, 8, 64); - TC_CONFIG_TILE(TMM2, 16, 32); - TC_CONFIG_TILE(TMM3, 16, 32); - TC_CONFIG_TILE(TMM4, 16, 64); - TC_CONFIG_TILE(TMM5, 16, 64); - TC_CONFIG_TILE(TMM6, 16, 64); - TC_CONFIG_TILE(TMM7, 16, 64); - _tile_loadconfig(&tc); - } - - is_first_time = false; + _tile_loadconfig(&tc); + done = true; } // we need an extra 16 * 4B (TILE_N * int32_t) for each NB/KB block for compensation. @@ -268,33 +259,6 @@ int get_row_size(int K) { return row_size; } -// vectorized dtype conversion -inline float FP16_TO_FP32(ggml_half val) { - __m256i v = _mm256_setr_epi16( - val, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); - __m512 o = _mm512_cvtph_ps(v); - return _mm512_cvtss_f32(o); -} - -inline __m512 FP16_TO_FP32_VEC(ggml_half val) { - __m256i v = _mm256_set1_epi16(val); - return _mm512_cvtph_ps(v); -} - -// horizontal reduce -inline float _mm512_reduce_max_ps(const __m512 x) { - __m512 v = x; - __m512 v1 = _mm512_shuffle_f32x4(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_f32x4(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0x4E); - v = _mm512_max_ps(v, v1); - v1 = _mm512_shuffle_ps(v, v, 0xB1); - v = _mm512_max_ps(v, v1); - return _mm512_cvtss_f32(v); -} - // transpose utils #define SHUFFLE_EPI32(a, b, mask) \ _mm256_castps_si256(_mm256_shuffle_ps(_mm256_castsi256_ps(a), _mm256_castsi256_ps(b), mask)) @@ -1370,9 +1334,9 @@ struct tinygemm_kernel_avx #define LAUNCH_TINYGEMM_KERNEL_AVX(MB_SIZE, NB_SIZE) \ tinygemm_kernel_avx::apply( \ - K, (const float *)src1->data + mb_start * K, \ - (const type *)src0->data + nb_start * K, \ - (float *)dst->data + mb_start * ldc + nb_start, ldc); + K, (const float *)src1->data + src1_offset + mb_start * K, \ + (const type *)src0->data + src0_offset + nb_start * K, \ + (float *)dst->data + dst_offset + mb_start * ldc + nb_start, ldc) // re-organize in the format {NB, KB, TILE_SIZE}: @@ -2019,11 +1983,11 @@ struct tinygemm_kernel_vnni::apply( \ - KB, (const char *)wdata + 0 * row_size_A, \ - (const char *)src0->data + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ - (float *) dst->data + 0 * N + nb_start, ldc) +#define LAUNCH_TINYGEMM_KERNEL_VNNI(NB_SIZE) \ + tinygemm_kernel_vnni::apply( \ + KB, wdata_batch, \ + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * kTilesN, 0, KB, TILE_SIZE), \ + (float *) dst->data + dst_offset + nb_start, ldc) template ::value, int>::type = 0> @@ -2079,7 +2043,7 @@ void tinygemm_kernel_amx(int M, int N, int KB, const void * RESTRICT _A, const v _tile_stored(TMM5, Tile5(C_pre), TILE_N * sizeof(int32_t)); if (need_unpack) { - unpack_B(Tile1, B_blk0); + unpack_B(Tile1, B_blk1); _tile_loadd(TMM1, Tile1, TILE_N * VNNI_BLK); } else { _tile_loadd(TMM1, B_blk1, TILE_N * VNNI_BLK); @@ -2336,6 +2300,13 @@ void ggml_backend_amx_convert_weight(struct ggml_tensor * tensor, const void * d }); } +// ne2 is passed explicitly to help compiler optimize repeated calls +inline int64_t ggml_batch_offset(const ggml_tensor * t, int64_t batch_idx, int64_t ne2) { + const int64_t i2 = batch_idx % ne2; + const int64_t i3 = batch_idx / ne2; + return i3 * t->nb[3] + i2 * t->nb[2]; +} + size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2348,12 +2319,13 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { const int M = dst->ne[1]; const int K = src0->ne[0]; + const int64_t n_batch = dst->ne[2] * dst->ne[3]; size_t desired_wsize = 0; GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - desired_wsize = M * row_size_A; + desired_wsize = n_batch * M * row_size_A; }); return desired_wsize; @@ -2365,7 +2337,7 @@ size_t ggml_backend_amx_desired_wsize(const struct ggml_tensor * dst) { // src1: input in shape of {M, K}, float32 // dst: output in shape of {M, N}, float32 // -// the function performs: dst = src1 @ src0.T +// the function performs: dst = src1 @ src0.T for each batch // void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_tensor * dst) { struct ggml_tensor * src0 = dst->src[0]; @@ -2382,17 +2354,26 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int K = src0->ne[0]; const int ldc = dst->nb[1] / dst->nb[0]; + const int64_t ne2 = dst->ne[2]; + const int64_t n_batch = ne2 * dst->ne[3]; + if (is_floating_type) { constexpr int BLOCK_M = 4; constexpr int BLOCK_N = 6; const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { GGML_DISPATCH_FLOATING_TYPES(TYPE, [&] { for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2424,10 +2405,10 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te void * wdata = params->wdata; //TODO: performance improvement: merge quant A - if (params->ith == 0) { + // if (params->ith == 0) { GGML_DISPATCH_QTYPES(TYPE, [&] { const size_t row_size_A = K / blck_size * sizeof(vec_dot_type); - const size_t desired_wsize = M * row_size_A; + const size_t desired_wsize = n_batch * M * row_size_A; if (params->wsize < desired_wsize) { GGML_ABORT("insufficient work space size"); } @@ -2436,12 +2417,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te // Q4_K, Q5_K, Q6_K, IQ4_XS handles 8 TILE_K per blck_size GGML_ASSERT(TILE_K == blck_size || TILE_K * 8 == blck_size); - const float * A_data = static_cast(src1->data); - for (int m = 0; m < M; ++m) { - from_float(A_data + m * K, (char *)wdata + m * row_size_A, K); - } + parallel_for_ggml(params, n_batch, [&](int begin, int end) { + for (int batch_idx = begin; batch_idx < end; ++batch_idx) { + int64_t src1_offset = ggml_batch_offset(src1, batch_idx, ne2); + const float * A_data = (const float *)((const char *)src1->data + src1_offset); + char * wdata_batch = (char *)wdata + batch_idx * M * row_size_A; + + for (int m = 0; m < M; ++m) { + from_float(A_data + m * K, wdata_batch + m * row_size_A, K); + } + } + }); }); - } + // } ggml_barrier(params->threadpool); @@ -2451,13 +2439,19 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te constexpr int BLOCK_N = TILE_N * kTilesN; const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * NB, [&](int begin, int end) { GGML_DISPATCH_QTYPES(TYPE, [&] { const int KB = K / blck_size; const int TILE_SIZE = get_tile_size(); const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int nb = i; + int batch_idx = i / NB; + int nb = i % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * row_size_A; + int nb_start = nb * BLOCK_N; int nb_size = std::min(BLOCK_N, N - nb_start); // 32, 64, 96 @@ -2481,7 +2475,7 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int MB = div_up(M, BLOCK_M); const int NB = div_up(N, BLOCK_N); - parallel_for_ggml(params, MB * NB, [&](int begin, int end) { + parallel_for_ggml(params, n_batch * MB * NB, [&](int begin, int end) { // init tile config for each thread ggml_tile_config_init(); @@ -2491,8 +2485,14 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te const int row_size_A = KB * sizeof(vec_dot_type); for (int i = begin; i < end; ++i) { - int mb = i / NB; - int nb = i % NB; + int batch_idx = i / (MB * NB); + int remaining = i % (MB * NB); + int mb = remaining / NB; + int nb = remaining % NB; + + int64_t src0_offset = ggml_batch_offset(src0, batch_idx, ne2); + int64_t dst_offset = ggml_batch_offset(dst, batch_idx, ne2); + const char * wdata_batch = (const char *)wdata + batch_idx * M * row_size_A; int mb_start = mb * BLOCK_M; int mb_size = std::min(BLOCK_M, M - mb_start); @@ -2501,9 +2501,9 @@ void ggml_backend_amx_mul_mat(const ggml_compute_params * params, struct ggml_te tinygemm_kernel_amx( mb_size, nb_size, KB, - (const char *)wdata + mb_start * row_size_A, - (const char *)src0->data + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), - (float *) dst->data + mb_start * N + nb_start, ldc); + wdata_batch + mb_start * row_size_A, + (const char *)src0->data + src0_offset + PACKED_INDEX(nb * 2, 0, KB, TILE_SIZE), + (float *) dst->data + dst_offset + mb_start * N + nb_start, ldc); } }); }); diff --git a/ggml/src/ggml-cpu/arch-fallback.h b/ggml/src/ggml-cpu/arch-fallback.h index 55526e6fb3..ebbd4b47e0 100644 --- a/ggml/src/ggml-cpu/arch-fallback.h +++ b/ggml/src/ggml-cpu/arch-fallback.h @@ -42,11 +42,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -55,11 +58,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K -#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K +#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64) @@ -67,8 +73,10 @@ #define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4 #define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64) // repack.cpp @@ -77,19 +85,23 @@ #define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0 #define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0 #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 #define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0 #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__POWERPC__) || defined(__powerpc__) @@ -110,11 +122,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -123,11 +138,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__loongarch64) @@ -148,11 +166,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -161,11 +182,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__riscv) @@ -187,11 +211,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -199,11 +226,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__s390x__) @@ -230,11 +260,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -243,11 +276,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #elif defined(__wasm__) @@ -276,11 +312,14 @@ #define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K #define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K #define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K +#define ggml_gemv_q5_K_8x4_q8_K_generic ggml_gemv_q5_K_8x4_q8_K #define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K #define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K #define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K #define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0 #define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0 +#define ggml_gemv_mxfp4_4x4_q8_0_generic ggml_gemv_mxfp4_4x4_q8_0 +#define ggml_gemv_mxfp4_8x8_q8_0_generic ggml_gemv_mxfp4_8x8_q8_0 #define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0 #define ggml_gemv_q8_0_4x8_q8_0_generic ggml_gemv_q8_0_4x8_q8_0 #define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0 @@ -289,11 +328,14 @@ #define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K #define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K #define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K +#define ggml_gemm_q5_K_8x4_q8_K_generic ggml_gemm_q5_K_8x4_q8_K #define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K #define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K #define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K #define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0 #define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0 +#define ggml_gemm_mxfp4_4x4_q8_0_generic ggml_gemm_mxfp4_4x4_q8_0 +#define ggml_gemm_mxfp4_8x8_q8_0_generic ggml_gemm_mxfp4_8x8_q8_0 #define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0 #define ggml_gemm_q8_0_4x8_q8_0_generic ggml_gemm_q8_0_4x8_q8_0 #endif diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index 3a3b32efb2..3eed0105bf 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -498,6 +498,81 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + float * res_ptr = s; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf = vdupq_n_f32(0); + for (int l = 0; l < nb; l++) { + uint8x16_t b_0 = vld1q_u8(b_ptr[l].qs + 0); + uint8x16_t b_1 = vld1q_u8(b_ptr[l].qs + 16); + uint8x16_t b_2 = vld1q_u8(b_ptr[l].qs + 32); + uint8x16_t b_3 = vld1q_u8(b_ptr[l].qs + 48); + + int8x16_t b_0_hi = vqtbl1q_s8(kvalues, b_0 >> 4); + int8x16_t b_0_lo = vqtbl1q_s8(kvalues, b_0 & 0x0F); + int8x16_t b_1_hi = vqtbl1q_s8(kvalues, b_1 >> 4); + int8x16_t b_1_lo = vqtbl1q_s8(kvalues, b_1 & 0x0F); + int8x16_t b_2_hi = vqtbl1q_s8(kvalues, b_2 >> 4); + int8x16_t b_2_lo = vqtbl1q_s8(kvalues, b_2 & 0x0F); + int8x16_t b_3_hi = vqtbl1q_s8(kvalues, b_3 >> 4); + int8x16_t b_3_lo = vqtbl1q_s8(kvalues, b_3 & 0x0F); + + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16); + + int32x4_t sumi = vdupq_n_s32(0); + sumi = vdotq_laneq_s32(sumi, b_0_lo, a_0, 0); + sumi = vdotq_laneq_s32(sumi, b_0_hi, a_1, 0); + sumi = vdotq_laneq_s32(sumi, b_1_lo, a_0, 1); + sumi = vdotq_laneq_s32(sumi, b_1_hi, a_1, 1); + sumi = vdotq_laneq_s32(sumi, b_2_lo, a_0, 2); + sumi = vdotq_laneq_s32(sumi, b_2_hi, a_1, 2); + sumi = vdotq_laneq_s32(sumi, b_3_lo, a_0, 3); + sumi = vdotq_laneq_s32(sumi, b_3_hi, a_1, 3); + + float32x4_t a_d = vcvt_f32_f16(vld1_dup_f16((const float16_t *)&a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + float32x4_t d = a_d * b_d; + + sumf = vmlaq_f32(sumf, d, vcvtq_f32_s32(sumi)); + } + + vst1q_f32(res_ptr + x * 4, sumf); + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemv_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -785,6 +860,165 @@ void ggml_gemv_q4_K_8x8_q8_K(int n, ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_q5_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567 + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 1x8 tile = 2 x 4 + float32x4_t acc_f32[col_groups]; + + const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy; + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < col_groups; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + float32x4_t q5_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); // d0 d1 d2 d3 + float32x4_t q5_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); // d4 d5 d6 d7 + float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d); + float32x4_t sb_scale_0123 = vmulq_f32(q5_d_0, q8_d); + float32x4_t sb_scale_4567 = vmulq_f32(q5_d_1, q8_d); + float32x4_t q5_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); // dmin 0..3 + float32x4_t q5_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); // dmin 4..7 + float32x4_t sb_min_0123 = vmulq_f32(q5_dmin_0, q8_d); + float32x4_t sb_min_4567 = vmulq_f32(q5_dmin_1, q8_d); + + // interleaved bias_acc: [0]->r0 0123, [1]->r0 4567 + int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) }; + int32x4_t acc_lo[col_groups]; + int32x4_t acc_hi[col_groups]; + + // Each bsum is 16 elements, pairwise add leaves us with the 8 bsums of the entire block + const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8)); + int16_t bsums_arr[8]; + vst1q_s16(bsums_arr, bsums); + + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + for (int i = 0; i < col_groups; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_mins[2]; + int16x8_t q5sb_scales[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + int8x16_t q8_qs[4]; + for (int i = 0; i < 4; i++) { + q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16); + } + + for (int c = 0; c < col_groups; c++) { + uint8x16_t q5_cols[8]; + uint8x16_t hbit_lo[8]; + uint8x16_t hbit_hi[8]; + int8x16_t q5_lo[8]; + int8x16_t q5_hi[8]; + + for (int i = 0; i < 8; i++) { + q5_cols[i] = vld1q_u8(q5_ptr[b].qs + sb * QK_K + i * 32 + 16 * c); + hbit_lo[i] = vandq_u8(qh[c][i], mone); + hbit_hi[i] = vshlq_n_u8(vandq_u8(qh[c][i], mtwo), 3); + qh[c][i] = vshrq_n_u8(qh[c][i], 2); + q5_lo[i] = vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_cols[i], m4b), hbit_lo[i], 4)); + q5_hi[i] = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_cols[i], 4), hbit_hi[i])); + } + + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[0], q8_qs[0], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[1], q8_qs[0], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[2], q8_qs[0], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[3], q8_qs[0], 3); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[4], q8_qs[1], 0); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[5], q8_qs[1], 1); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[6], q8_qs[1], 2); + acc_lo[c] = vdotq_laneq_s32(acc_lo[c], q5_lo[7], q8_qs[1], 3); + + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[0], q8_qs[2], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[1], q8_qs[2], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[2], q8_qs[2], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[3], q8_qs[2], 3); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[4], q8_qs[3], 0); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[5], q8_qs[3], 1); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[6], q8_qs[3], 2); + acc_hi[c] = vdotq_laneq_s32(acc_hi[c], q5_hi[7], q8_qs[3], 3); + } + + // Scales + // row c0123 blk0 and blk1 + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0]))); + acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123); + // row c4567 blk0 and blk1 + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1]))); + acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567); + + // Bias Correction + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]); + + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } // for sb + + acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123); + acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567); + } // for b + + int base = x * ncols_interleaved; + vst1q_f32(s + base, acc_f32[0]); + vst1q_f32(s + base + 4, acc_f32[1]); + } // for x + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemv_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, @@ -3005,6 +3239,87 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert (n % qk == 0); + assert (nr % 4 == 0); + assert (nc % ncols_interleaved == 0); + + UNUSED(s); + UNUSED(bs); + UNUSED(vx); + UNUSED(vy); + UNUSED(nr); + UNUSED(nc); + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + const int8x16_t kvalues = vld1q_s8(kvalues_mxfp4); + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + float32x4_t sumf[4]; + for (int m = 0; m < 4; m++) { + sumf[m] = vdupq_n_f32(0); + } + + for (int l = 0; l < nb; l++) { + float32x4_t a_d = vcvt_f32_f16(vld1_f16((const float16_t *)a_ptr[l].d)); + float32x4_t b_d = { + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[3]), + }; + + int32x4_t sumi_0 = vdupq_n_s32(0); + int32x4_t sumi_1 = vdupq_n_s32(0); + int32x4_t sumi_2 = vdupq_n_s32(0); + int32x4_t sumi_3 = vdupq_n_s32(0); + + for (int k = 0; k < 4; k++) { + int8x16_t a_0 = vld1q_s8(a_ptr[l].qs + 16 * k + 0); + int8x16_t a_1 = vld1q_s8(a_ptr[l].qs + 16 * k + 64); + + uint8x16_t b = vld1q_u8(b_ptr[l].qs + 16 * k); + int8x16_t b_hi = vqtbl1q_s8(kvalues, b >> 4); + int8x16_t b_lo = vqtbl1q_s8(kvalues, b & 0xF); + + sumi_0 = vdotq_laneq_s32(sumi_0, b_lo, a_0, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_lo, a_0, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_lo, a_0, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_lo, a_0, 3); + sumi_0 = vdotq_laneq_s32(sumi_0, b_hi, a_1, 0); + sumi_1 = vdotq_laneq_s32(sumi_1, b_hi, a_1, 1); + sumi_2 = vdotq_laneq_s32(sumi_2, b_hi, a_1, 2); + sumi_3 = vdotq_laneq_s32(sumi_3, b_hi, a_1, 3); + } + + sumf[0] = vmlaq_f32(sumf[0], vmulq_laneq_f32(b_d, a_d, 0), vcvtq_f32_s32(sumi_0)); + sumf[1] = vmlaq_f32(sumf[1], vmulq_laneq_f32(b_d, a_d, 1), vcvtq_f32_s32(sumi_1)); + sumf[2] = vmlaq_f32(sumf[2], vmulq_laneq_f32(b_d, a_d, 2), vcvtq_f32_s32(sumi_2)); + sumf[3] = vmlaq_f32(sumf[3], vmulq_laneq_f32(b_d, a_d, 3), vcvtq_f32_s32(sumi_3)); + } + + for (int m = 0; m < 4; m++) { + vst1q_f32(s + (y * 4 + m) * bs + x * 4, sumf[m]); + } + } + } + return; +#endif // #if ! ((defined(_MSC_VER)) && ! defined(__clang__)) && defined(__aarch64__) && defined(__ARM_NEON) + ggml_gemm_mxfp4_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { constexpr int qk = QK_K; const int nb = n / qk; @@ -3205,6 +3520,235 @@ void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_q5_K_8x4_q8_K(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int qk = QK_K; + const int nb = n / qk; + + constexpr int ncols_interleaved = 8; + constexpr int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(nb); + UNUSED(ncols_interleaved); + UNUSED(blocklen); + +#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + constexpr int q8_k_blocklen = 4; + constexpr int acc_size = 2 * 4; // 2 row pairs, 4 col pairs + constexpr int col_groups = ncols_interleaved / 4; + const uint8x16_t m4b = vdupq_n_u8(0x0f); + const uint8x16_t mone = vdupq_n_u8(1); + const uint8x16_t mtwo = vdupq_n_u8(2); + + // 8 accumulators: 2 row pairs, 4 col pairs + float32x4_t acc_f32[acc_size]; + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * GGML_RESTRICT q5_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int i = 0; i < acc_size; i++) { + acc_f32[i] = vdupq_n_f32(0); + } + + for (int b = 0; b < nb; b++) { + // d5 0 1 2 3, 4 5 6 7 + float32x4_t q5_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d)); + float32x4_t q5_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].d + 4)); + // d8 0 1 2 3 + float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d); + // mins + float32x4_t q5_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin)); + float32x4_t q5_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q5_ptr[b].dmin + 4)); + + // Precomputation of scales and mins + float32x4_t sbd_scale_0123[q8_k_blocklen]; + float32x4_t sbd_scale_4567[q8_k_blocklen]; + float32x4_t sbd_min_0123[q8_k_blocklen]; + float32x4_t sbd_min_4567[q8_k_blocklen]; + + sbd_scale_0123[0] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 0); + sbd_scale_4567[0] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 0); + sbd_min_0123[0] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 0); + sbd_min_4567[0] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 0); + + sbd_scale_0123[1] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 1); + sbd_scale_4567[1] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 1); + sbd_min_0123[1] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 1); + sbd_min_4567[1] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 1); + + sbd_scale_0123[2] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 2); + sbd_scale_4567[2] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 2); + sbd_min_0123[2] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 2); + sbd_min_4567[2] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 2); + + sbd_scale_0123[3] = vmulq_laneq_f32(q5_d_0123, q8_d_0123, 3); + sbd_scale_4567[3] = vmulq_laneq_f32(q5_d_4567, q8_d_0123, 3); + sbd_min_0123[3] = vmulq_laneq_f32(q5_dmin_0123, q8_d_0123, 3); + sbd_min_4567[3] = vmulq_laneq_f32(q5_dmin_4567, q8_d_0123, 3); + + // Precomputation of bsums, each vpaddq calcs all the bsums for each row + const int16x8_t bsums[q8_k_blocklen] = { + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + int16_t bsums_arr[QK_K / 64][8]; + for (int q8_row = 0; q8_row < 4; q8_row++) { + vst1q_s16(bsums_arr[q8_row], bsums[q8_row]); + } + + // interleaved bias_acc: [0]->r0 0123, [1]->r1 0123, .., [4]->r0 4567, [5]->r1 4567 .. + int32x4_t bias_acc[acc_size]; + for (int i = 0; i < acc_size; i++) { + bias_acc[i] = vdupq_n_s32(0); + } + + uint8x16_t qh[col_groups][8]; + for (int c = 0; c < col_groups; c++) { + for (int i = 0; i < 8; i++) { + qh[c][i] = vld1q_u8(q5_ptr[b].qh + i * 32 + 16 * c); + } + } + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Int accumulators for qs vecdot (4 row * 2 col quartets) + int32x4_t acc_lo[acc_size]; + int32x4_t acc_hi[acc_size]; + for (int i = 0; i < acc_size; i++) { + acc_lo[i] = vdupq_n_s32(0); + acc_hi[i] = vdupq_n_s32(0); + } + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + int16x8_t q5sb_scales[2]; + int16x8_t q5sb_mins[2]; + for (int i = 0; i < 2; i++) { + int8_t aux_q5sb[8]; + const int offset = sb * 24 + i * 12; + decode_q_Kx8_6bit_scales(&q5_ptr[b].scales[offset], &q5sb_mins[i], aux_q5sb); + q5sb_scales[i] = vmovl_s8(vld1_s8(aux_q5sb)); + } + + constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows + for (int k = 0; k < reads_per_sb; k++) { + const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k); + const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128); + + // 0..3 & 32..35 + const uint8x16_t q5_0123 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k); + const uint8x16_t q5_4567 = vld1q_u8(q5_ptr[b].qs + sb * QK_K + 32 * k + 16); + + // NOTE: This is the only difference with q4_K + const uint8x16_t hbit_lo_0123 = vandq_u8(qh[0][k], mone); + const uint8x16_t hbit_hi_0123 = vshlq_n_u8(vandq_u8(qh[0][k], mtwo), 3); + qh[0][k] = vshrq_n_u8(qh[0][k], 2); + const uint8x16_t hbit_lo_4567 = vandq_u8(qh[1][k], mone); + const uint8x16_t hbit_hi_4567 = vshlq_n_u8(vandq_u8(qh[1][k], mtwo), 3); + qh[1][k] = vshrq_n_u8(qh[1][k], 2); + // From here, same as q4_K + + const int8x16_t q5_0123_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_0123, m4b), hbit_lo_0123, 4)); + const int8x16_t q5_0123_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_0123, 4), hbit_hi_0123)); + + acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q5_0123_lo, q8_blk0, 0); // 0..3 r0 c0123 + acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q5_0123_lo, q8_blk0, 1); // 0..3 r1 c0123 + acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q5_0123_lo, q8_blk0, 2); // 0..3 r2 c0123 + acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q5_0123_lo, q8_blk0, 3); // 0..3 r3 c0123 + + acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q5_0123_hi, q8_blk1, 0); // 32..35 r0 c0123 + acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q5_0123_hi, q8_blk1, 1); // 32..35 r1 c0123 + acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q5_0123_hi, q8_blk1, 2); // 32..35 r2 c0123 + acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q5_0123_hi, q8_blk1, 3); // 32..35 r3 c0123 + + const int8x16_t q5_4567_lo = + vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q5_4567, m4b), hbit_lo_4567, 4)); + const int8x16_t q5_4567_hi = + vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q5_4567, 4), hbit_hi_4567)); + + acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q5_4567_lo, q8_blk0, 0); // 0..3 r0 c4567 + acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q5_4567_lo, q8_blk0, 1); // 0..3 r1 c4567 + acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q5_4567_lo, q8_blk0, 2); // 0..3 r2 c4567 + acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q5_4567_lo, q8_blk0, 3); // 0..3 r3 c4567 + + acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q5_4567_hi, q8_blk1, 0); // 32..35 r0 c4567 + acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q5_4567_hi, q8_blk1, 1); // 32..35 r1 c4567 + acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q5_4567_hi, q8_blk1, 2); // 32..35 r2 c4567 + acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q5_4567_hi, q8_blk1, 3); // 32..35 r3 c4567 + } + + // Scale and bias application + // acc is stored interleaved to match output layout + const int16x4_t sc_0123_lo = vget_low_s16(q5sb_scales[0]); + const int16x4_t sc_4567_lo = vget_high_s16(q5sb_scales[0]); + const int16x4_t sc_0123_hi = vget_low_s16(q5sb_scales[1]); + const int16x4_t sc_4567_hi = vget_high_s16(q5sb_scales[1]); + for (int row = 0; row < q8_k_blocklen; row++) { + // Bias correction + // row c0123 blk0 and blk1 + const float32x4_t sumf_0123 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]), + vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row]))); + acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123); + + // row c4567 blk0 and blk1 + const float32x4_t sumf_4567 = + vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]), + vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4]))); + acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567); + + // Bias + const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]); + const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]); + + // row c0123 blk0 and blk1 + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q5sb_mins[0])); + bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q5sb_mins[1])); + + // row c4567 blk0 and blk1 + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q5sb_mins[0])); + bias_acc[2 * row + 1] = + vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q5sb_mins[1])); + } + } // for sb + + for (int row = 0; row < q8_k_blocklen; row++) { + acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]); + acc_f32[2 * row + 1] = + vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]); + } + } // for b + + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + vst1q_f32(s + offset, acc_f32[2 * i + j]); + } + } + } // for x + } // for y + return; +#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD) + ggml_gemm_q5_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, diff --git a/ggml/src/ggml-cpu/arch/x86/repack.cpp b/ggml/src/ggml-cpu/arch/x86/repack.cpp index 7dda9eea0c..bd6906c415 100644 --- a/ggml/src/ggml-cpu/arch/x86/repack.cpp +++ b/ggml/src/ggml-cpu/arch/x86/repack.cpp @@ -522,7 +522,8 @@ template static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v || - std::is_same_v, + std::is_same_v || + std::is_same_v, "Unsupported block type"); const int qk = QK8_0; @@ -580,6 +581,18 @@ static void gemv_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_REARRANGE_LOAD(b_ptr[b].d, changemask); + } else if constexpr (std::is_same_v) { + // Load 8 E8M0 exponents and convert to float via LUT + // Rearranged to match changemask order: 0,4,1,5,2,6,3,7 + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load and convert to FP32 scale from block_q8_0 @@ -628,7 +641,8 @@ template static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc, __m256i signextendlut) { static_assert( std::is_same_v || - std::is_same_v, + std::is_same_v || + std::is_same_v, "Unsupported block type"); const int qk = QK8_0; @@ -749,6 +763,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Process LHS in pairs of rows @@ -941,6 +974,25 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8x2_LOAD(b_ptr_0[b].d, b_ptr_1[b].d); + } else if constexpr (std::is_same_v) { + //TODO: simd-ify + col_scale_f32 = _mm512_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_1[b].e[0]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr_0[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1123,6 +1175,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Process LHS in groups of four @@ -1283,6 +1345,16 @@ static void gemm_q4_b32_8x8_q8_0_lut_avx(int n, float * GGML_RESTRICT s, size_t std::is_same_v || std::is_same_v) { col_scale_f32 = GGML_F32Cx8_LOAD(b_ptr[b].d); + } else if constexpr (std::is_same_v) { + col_scale_f32 = _mm256_set_ps( + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[7]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[6]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[5]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[4]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[3]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[2]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[1]), + GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[b].e[0])); } // Load the four blocks of quantized values interleaved with each other in chunks of eight - A0,A1,A2,A3 @@ -1625,6 +1697,19 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemv_iq4_nl_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); } +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemv_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; +#endif + + ggml_gemv_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; @@ -3423,6 +3508,21 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const ggml_gemm_iq4_nl_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { +#if defined(__AVX2__) || defined(__AVX512F__) + { + __m256i signextendlut = _mm256_castsi128_si256(_mm_loadu_si128((const __m128i*)kvalues_mxfp4)); + signextendlut = _mm256_permute2f128_si256(signextendlut, signextendlut, 0); + + gemm_q4_b32_8x8_q8_0_lut_avx(n, s, bs, vx, vy, nr, nc, signextendlut); + + return; + } +#endif // defined(__AVX2__) || defined(__AVX512F__) + + ggml_gemm_mxfp4_8x8_q8_0_generic(n, s, bs, vx, vy, nr, nc); +} + void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { const int qk = QK_K; const int nb = n / qk; diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index f94426ddd7..5edba4212f 100644 --- a/ggml/src/ggml-cpu/repack.cpp +++ b/ggml/src/ggml-cpu/repack.cpp @@ -450,6 +450,208 @@ static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n, } } +template +static void ggml_gemv_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[ncols_interleaved]; + float sum_minf[ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + const block_q8_K * a_ptr = (const block_q8_K *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) { + sumf[j] = 0.0; + sum_minf[j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 64 + (k % (32 / blocklen)) * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; + } + } + } + for (int j = 0; j < ncols_interleaved; j++) { + s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; + } + } +} + +template +static void ggml_gemm_q5_K_NxM_q8_K_generic_impl(int n, + float * GGML_RESTRICT s, + size_t bs, + const void * GGML_RESTRICT vx, + const void * GGML_RESTRICT vy, + int nr, + int nc) { + constexpr int blocklen = M; + constexpr int ncols_interleaved = N; + const int qk = QK_K; + const int nb = n / qk; + static const uint32_t kmask1 = 0x3f3f3f3f; + static const uint32_t kmask2 = 0x0f0f0f0f; + static const uint32_t kmask3 = 0x03030303; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][ncols_interleaved]; + float sum_minf[4][ncols_interleaved]; + uint32_t utmp[32]; + int sumi1; + int sumi2; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumf[m][j] = 0.0; + sum_minf[m][j] = 0.0; + } + } + for (int l = 0; l < nb; l++) { + for (int sb = 0; sb < 8; sb++) { + memcpy(utmp + sb * 4, b_ptr[l].scales + sb * K_SCALE_SIZE, K_SCALE_SIZE); + utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); + const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; + utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); + utmp[sb * 4 + 2] = uaux_0; + utmp[sb * 4 + 0] &= kmask1; + } + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + constexpr int scale_stride = 32; + uint8_t * scales_0 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride; + uint8_t * scales_1 = (uint8_t *) utmp + (k / (32 / blocklen)) * scale_stride + 16; + + const int qh_shift = (k / (32 / blocklen)) * 2; + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi1 = 0; + sumi2 = 0; + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; + + const int qh_idx = (k * blocklen + i) % 32; + const int qh_chunk = qh_idx / blocklen; + const int qh_pos = qh_idx % blocklen; + const int b_qh_offset = + qh_chunk * (blocklen * ncols_interleaved) + j * blocklen + qh_pos; + + const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; + const uint8_t h0 = (qh_val >> qh_shift) & 1; + const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; + + const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); + const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); + + const int q8_offset = (k / (32 / blocklen)) * 256 + + (k % (32 / blocklen)) * 4 * blocklen + m * blocklen + i; + + sumi1 = (v0 * a_ptr[l].qs[q8_offset]); + sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); + sumi1 = sumi1 * scales_0[j]; + sumi2 = sumi2 * scales_1[j]; + sumi += sumi1 + sumi2; + } + sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; + } + } + } + for (int sb = 0; sb < 8; sb++) { + uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; + for (int m = 0; m < 4; m++) { + const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); + for (int j = 0; j < ncols_interleaved; j++) { + sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * + GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; + } + } + } + } +} + extern "C" { void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -803,98 +1005,12 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemv_q5_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; - static const uint32_t kmask1 = 0x3f3f3f3f; - static const uint32_t kmask2 = 0x0f0f0f0f; - static const uint32_t kmask3 = 0x03030303; +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - assert(n % qk == 0); - assert(nc % ncols_interleaved == 0); - - UNUSED(bs); - UNUSED(nr); - - float sumf[8]; - float sum_minf[8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - - const block_q8_K * a_ptr = (const block_q8_K *) vy; - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - - for (int j = 0; j < ncols_interleaved; j++) { - sumf[j] = 0.0; - sum_minf[j] = 0.0; - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; - - const int qh_shift = (k / 4) * 2; - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * 8 + i) % 32; - const int qh_chunk = qh_idx / 8; - const int qh_pos = qh_idx % 8; - const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k >> 2) * 64 + (k % 4) * blocklen + i; - - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 32]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d; - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d; - } - } - } - for (int j = 0; j < ncols_interleaved; j++) { - s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j]; - } - } +void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemv_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } @@ -982,6 +1098,82 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[4]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(nr == 1); + assert(n % qk == 0); + assert(nc % ncols_interleaved == 0); + + UNUSED(bs); + UNUSED(nr); + + float sumf[8]; + int sumi; + + const block_q8_0 * a_ptr = (const block_q8_0 *) vy; + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); + + for (int j = 0; j < ncols_interleaved; j++) sumf[j] = 0.0; + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * blocklen + i]) + (v1 * a_ptr[l].qs[k * blocklen + i + qk / 2])); + } + sumf[j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d); + } + } + } + for (int j = 0; j < ncols_interleaved; j++) s[x * ncols_interleaved + j] = sumf[j]; + } +} + void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -1494,107 +1686,12 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, } } -void ggml_gemm_q5_K_8x8_q8_K_generic(int n, - float * GGML_RESTRICT s, - size_t bs, - const void * GGML_RESTRICT vx, - const void * GGML_RESTRICT vy, - int nr, - int nc) { - const int qk = QK_K; - const int nb = n / qk; - const int ncols_interleaved = 8; - const int blocklen = 8; +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc); +} - constexpr uint32_t kmask1 = 0x3f3f3f3f; - constexpr uint32_t kmask2 = 0x0f0f0f0f; - constexpr uint32_t kmask3 = 0x03030303; - - assert(n % qk == 0); - assert(nr % 4 == 0); - assert(nc % ncols_interleaved == 0); - - float sumf[4][8]; - float sum_minf[4][8]; - uint32_t utmp[32]; - int sumi1; - int sumi2; - int sumi; - - for (int y = 0; y < nr / 4; y++) { - const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb); - for (int x = 0; x < nc / ncols_interleaved; x++) { - const block_q5_Kx8 * b_ptr = (const block_q5_Kx8 *) vx + (x * nb); - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumf[m][j] = 0.0; - sum_minf[m][j] = 0.0; - } - } - for (int l = 0; l < nb; l++) { - for (int sb = 0; sb < 8; sb++) { - memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12); - utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4); - const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1; - utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4); - utmp[sb * 4 + 2] = uaux_0; - utmp[sb * 4 + 0] &= kmask1; - } - for (int k = 0; k < (qk / (2 * blocklen)); k++) { - uint8_t * scales_0 = (uint8_t *) utmp + (k / 4) * 32; - uint8_t * scales_1 = (uint8_t *) utmp + (k / 4) * 32 + 16; - - const int qh_shift = (k / 4) * 2; - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - sumi1 = 0; - sumi2 = 0; - sumi = 0; - for (int i = 0; i < blocklen; ++i) { - const int b_qs_offset = k * ncols_interleaved * blocklen + j * blocklen + i; - - const int qh_idx = (k * 8 + i) % 32; - const int qh_chunk = qh_idx / 8; - const int qh_pos = qh_idx % 8; - const int b_qh_offset = qh_chunk * 64 + j * 8 + qh_pos; - - const uint8_t qh_val = b_ptr[l].qh[b_qh_offset]; - const uint8_t h0 = (qh_val >> qh_shift) & 1; - const uint8_t h1 = (qh_val >> (qh_shift + 1)) & 1; - - const int v0 = (int8_t) ((b_ptr[l].qs[b_qs_offset] & 0xF) | (h0 << 4)); - const int v1 = (int8_t) ((b_ptr[l].qs[b_qs_offset] >> 4) | (h1 << 4)); - - const int q8_offset = (k >> 2) * 256 + (k % 4) * 4 * blocklen + m * blocklen + i; - - sumi1 = (v0 * a_ptr[l].qs[q8_offset]); - sumi2 = (v1 * a_ptr[l].qs[q8_offset + 128]); - sumi1 = sumi1 * scales_0[j]; - sumi2 = sumi2 * scales_1[j]; - sumi += sumi1 + sumi2; - } - sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m]; - } - } - } - for (int sb = 0; sb < 8; sb++) { - uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16; - for (int m = 0; m < 4; m++) { - const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6); - for (int j = 0; j < ncols_interleaved; j++) { - sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * - GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m]; - } - } - } - } - for (int m = 0; m < 4; m++) { - for (int j = 0; j < ncols_interleaved; j++) { - s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j]; - } - } - } - } +void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + ggml_gemm_q5_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc); } void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { @@ -1705,6 +1802,94 @@ void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs } } +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 4; + const int blocklen = 4; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][4]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x4 * b_ptr = (const block_mxfp4x4 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) { + const int qk = QK8_0; + const int nb = n / qk; + const int ncols_interleaved = 8; + const int blocklen = 8; + + assert(n % qk == 0); + assert(nr % 4 == 0); + assert(nc % ncols_interleaved == 0); + + float sumf[4][8]; + int sumi; + + for (int y = 0; y < nr / 4; y++) { + const block_q8_0x4 * a_ptr = (const block_q8_0x4 *) vy + (y * nb); + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_mxfp4x8 * b_ptr = (const block_mxfp4x8 *) vx + (x * nb); + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) sumf[m][j] = 0.0; + } + for (int l = 0; l < nb; l++) { + for (int k = 0; k < (qk / (2 * blocklen)); k++) { + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) { + sumi = 0; + for (int i = 0; i < blocklen; ++i) { + const int v0 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0x0F]; + const int v1 = kvalues_mxfp4[b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4]; + sumi += ((v0 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i]) + + (v1 * a_ptr[l].qs[k * 4 * blocklen + m * blocklen + i + qk / 2 * 4])); + } + sumf[m][j] += sumi * GGML_CPU_E8M0_TO_FP32_HALF(b_ptr[l].e[j]) * GGML_CPU_FP16_TO_FP32(a_ptr[l].d[m]); + } + } + } + } + for (int m = 0; m < 4; m++) { + for (int j = 0; j < ncols_interleaved; j++) + s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j]; + } + } + } +} + void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, @@ -2029,18 +2214,16 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in const int end = QK_K * 4 / blck_size_interleave; - // Interleave Q5_K quants by taking 8 bytes at a time + // Interleave Q5_K quants by taking blck_size_interleave bytes at a time for (int i = 0; i < end; ++i) { int src_id = i % 8; int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t)); - memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], blck_size_interleave); } - // Repeat for low bits 8 bytes at a time as well, since + // Repeat for high bits with the same chunk size, since // the high bits are interleaved in Q5_K and the index is // qh_idx = (qs_idx % 32); // qh_val = qh[qh_idx] >> (qs_idx / 32); @@ -2049,9 +2232,7 @@ static block_q5_Kx8 make_block_q5_Kx8(block_q5_K * in, unsigned int blck_size_in int src_offset = (i / 8) * blck_size_interleave; int dst_offset = i * blck_size_interleave; - uint64_t elems; - memcpy(&elems, &in[src_id].qh[src_offset], sizeof(uint64_t)); - memcpy(&out.qh[dst_offset], &elems, sizeof(uint64_t)); + memcpy(&out.qh[dst_offset], &in[src_id].qh[src_offset], blck_size_interleave); } // The below logic is copied over from Q4_K @@ -2249,7 +2430,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t, const void * GGML_RESTRICT data, size_t data_size) { GGML_ASSERT(t->type == GGML_TYPE_Q5_K); - GGML_ASSERT(interleave_block == 8); + GGML_ASSERT(interleave_block == 4 || interleave_block == 8); constexpr int nrows_interleaved = 8; block_q5_Kx8 * dst = (block_q5_Kx8 *) t->data; @@ -2493,6 +2674,121 @@ static int repack_iq4_nl_to_iq4_nl_8_bl(struct ggml_tensor * t, int interleave_b GGML_UNUSED(data_size); } + +static block_mxfp4x4 make_block_mxfp4x4(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x4 out; + + for (int i = 0; i < 4; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 2 / blck_size_interleave; + + if (blck_size_interleave == 4) { + for (int i = 0; i < end; ++i) { + int src_id = i % 4; + int src_offset = (i / 4) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint32_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_4_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 4); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x4 * dst = ( block_mxfp4x4 *)t->data; + + block_mxfp4 dst_tmp[4]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 4; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % 8 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x4(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static block_mxfp4x8 make_block_mxfp4x8(block_mxfp4 * in, unsigned int blck_size_interleave) { + block_mxfp4x8 out; + + for (int i = 0; i < 8; i++) { + out.e[i] = in[i].e; + } + + const int end = QK_MXFP4 * 4 / blck_size_interleave; + + if (blck_size_interleave == 8) { + for (int i = 0; i < end; ++i) { + int src_id = i % 8; + int src_offset = (i / 8) * blck_size_interleave; + int dst_offset = i * blck_size_interleave; + + memcpy(&out.qs[dst_offset], &in[src_id].qs[src_offset], sizeof(uint64_t)); + } + } else { + GGML_ASSERT(false); + } + + return out; +} + +static int repack_mxfp4_to_mxfp4_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_MXFP4); + GGML_ASSERT(interleave_block == 8); + + const block_mxfp4 * src = (const block_mxfp4 *)data; + block_mxfp4x8 * dst = ( block_mxfp4x8 *)t->data; + + block_mxfp4 dst_tmp[8]; + + int nrow = ggml_nrows(t); + int nrows_interleaved = 8; + int nblocks = t->ne[0] / QK_MXFP4; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_mxfp4)); + + if (t->ne[1] % nrows_interleaved != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_mxfp4x8(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + namespace ggml::cpu::repack { // repack template @@ -2523,6 +2819,10 @@ template <> int repack(struct ggml_tensor * t, const void * da return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q5_K_to_q5_K_8_bl(t, 4, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size); } @@ -2548,6 +2848,14 @@ template <> int repack(struct ggml_tensor * t, const void * return repack_iq4_nl_to_iq4_nl_8_bl(t, 8, data, data_size); } +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_4_bl(t, 4, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_mxfp4_to_mxfp4_8_bl(t, 8, data, data_size); +} + template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { return repack_q8_0_to_q8_0_4_bl(t, 4, data, data_size); } @@ -2591,6 +2899,10 @@ template <> void gemv(int n, float * s, size_t ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2611,6 +2923,14 @@ template <> void gemv(int n, float * s, size ggml_gemv_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemv_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemv(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemv_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -2654,6 +2974,10 @@ template <> void gemm(int n, float * s, size_t ggml_gemm_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_q5_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc); } @@ -2674,6 +2998,14 @@ template <> void gemm(int n, float * s, size ggml_gemm_iq4_nl_8x8_q8_0(n, s, bs, vx, vy, nr, nc); } +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_4x4_q8_0(n, s, bs, vx, vy, nr, nc); +} + +template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { + ggml_gemm_mxfp4_8x8_q8_0(n, s, bs, vx, vy, nr, nc); +} + template <> void gemm(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) { ggml_gemm_q8_0_4x4_q8_0(n, s, bs, vx, vy, nr, nc); } @@ -3068,6 +3400,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits q4_K_8x8_q8_K; // instance for Q5_K + static const ggml::cpu::repack::tensor_traits q5_K_8x4_q8_K; static const ggml::cpu::repack::tensor_traits q5_K_8x8_q8_K; // instance for Q6_K @@ -3081,6 +3414,10 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons static const ggml::cpu::repack::tensor_traits iq4_nl_4x4_q8_0; static const ggml::cpu::repack::tensor_traits iq4_nl_8x8_q8_0; + // instance for MXFP4 + static const ggml::cpu::repack::tensor_traits mxfp4_4x4_q8_0; + static const ggml::cpu::repack::tensor_traits mxfp4_8x8_q8_0; + // instance for Q8_0 static const ggml::cpu::repack::tensor_traits q8_0_4x4_q8_0; static const ggml::cpu::repack::tensor_traits q8_0_4x8_q8_0; @@ -3130,6 +3467,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &q5_K_8x8_q8_K; } } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 8 == 0) { + return &q5_K_8x4_q8_K; + } + } } else if (cur->type == GGML_TYPE_Q6_K) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 8 == 0) { @@ -3152,6 +3494,17 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons return &iq4_nl_4x4_q8_0; } } + } else if (cur->type == GGML_TYPE_MXFP4) { + if (ggml_cpu_has_avx2()) { + if (cur->ne[1] % 8 == 0) { + return &mxfp4_8x8_q8_0; + } + } + if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) { + if (cur->ne[1] % 4 == 0) { + return &mxfp4_4x4_q8_0; + } + } } else if (cur->type == GGML_TYPE_Q8_0) { if (ggml_cpu_has_neon() && ggml_cpu_has_matmul_int8()) { if (cur->ne[1] % 4 == 0) { diff --git a/ggml/src/ggml-cpu/repack.h b/ggml/src/ggml-cpu/repack.h index 39b6b48238..b9f821630c 100644 --- a/ggml/src/ggml-cpu/repack.h +++ b/ggml/src/ggml-cpu/repack.h @@ -97,6 +97,19 @@ struct block_iq4_nlx8 { static_assert(sizeof(block_iq4_nlx8) == 8 * sizeof(ggml_half) + QK4_NL * 4, "wrong iq4_nlx8 block size/padding"); +struct block_mxfp4x4 { + uint8_t e[4]; + uint8_t qs[QK_MXFP4 * 2]; +}; +static_assert(sizeof(block_mxfp4x4) == 4 + QK_MXFP4 * 2, "wrong mxfp4x4 block size/padding"); + +struct block_mxfp4x8 { + uint8_t e[8]; + uint8_t qs[QK_MXFP4 * 4]; +}; +static_assert(sizeof(block_mxfp4x8) == 8 + QK_MXFP4 * 4, "wrong mxfp4x8 block size/padding"); + + #if defined(__cplusplus) extern "C" { #endif @@ -111,22 +124,28 @@ void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); @@ -143,22 +162,28 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemv_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_q5_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); +void ggml_gemm_mxfp4_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemv_q8_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); void ggml_gemm_q8_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index a3256d59dd..36d8a3aaab 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -1149,8 +1149,7 @@ struct ggml_cuda_graph { size_t num_nodes = 0; std::vector nodes; bool disable_due_to_gpu_arch = false; - bool disable_due_to_too_many_updates = false; - int number_consecutive_updates = 0; + bool warmup_complete = false; std::vector props; // these are extra tensors (inputs) that participate in the ggml graph but are not nodes @@ -1159,21 +1158,9 @@ struct ggml_cuda_graph { // ref: https://github.com/ggml-org/llama.cpp/pull/19165 std::vector extra; - void record_update(bool use_graph, bool update_required) { - if (use_graph && update_required) { - number_consecutive_updates++; - } else { - number_consecutive_updates = 0; - } - if (number_consecutive_updates >= 4) { - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); - disable_due_to_too_many_updates = true; - } - } - bool is_enabled() const { static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env); } #endif }; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index 09b6d5db6a..b70492c7d6 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ b/ggml/src/ggml-cuda/convert.cu @@ -16,27 +16,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ return; } - const int64_t i01 = blockIdx.y; + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); - - const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } } @@ -492,7 +492,7 @@ static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } @@ -628,18 +628,18 @@ static __global__ void convert_unary( return; } - const int64_t i01 = blockIdx.y; - const src_t * x = (const src_t *) vx; - for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { - const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); - const int64_t i02 = dm.y; - const int64_t i03 = dm.x; + for (int64_t i01 = blockIdx.y; i01 < ne01; i01 += gridDim.y) { + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } } @@ -649,7 +649,7 @@ static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { const int64_t ne0203 = ne02*ne03; const uint3 ne02_fdv = init_fastdiv_values(ne02); - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, (int)std::min(ne01, (int64_t)65535), (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } diff --git a/ggml/src/ggml-cuda/fattn-mma-f16.cuh b/ggml/src/ggml-cuda/fattn-mma-f16.cuh index 0b8ef90794..beb7e32e4f 100644 --- a/ggml/src/ggml-cuda/fattn-mma-f16.cuh +++ b/ggml/src/ggml-cuda/fattn-mma-f16.cuh @@ -111,6 +111,44 @@ static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_co return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); } +static constexpr __host__ __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config_cdna(const int DKQ, const int DV, const int ncols) { + // Conservative configs for CDNA (MI100+): 64KB LDS, wavefront64, nstages=1 (no cp.async). + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 8, 128, 2, 128, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 16, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 32, 128, 2, 64, 32, 32, 32, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 32, 32, 32, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 8, 128, 2, 128, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 16, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 32, 128, 2, 64, 40, 40, 40, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 80, 80, 64, 256, 2, 64, 40, 40, 40, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 8, 128, 2, 128, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 16, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 32, 128, 2, 64, 48, 48, 48, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE( 96, 96, 64, 256, 2, 64, 48, 48, 48, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 8, 128, 2, 128, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 16, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 32, 128, 2, 64, 56, 56, 56, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(112, 112, 64, 256, 2, 64, 56, 56, 56, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 8, 128, 2, 128, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 16, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 32, 128, 2, 64, 64, 64, 64, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(128, 128, 64, 256, 2, 64, 64, 64, 64, 1, true); + + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 8, 64, 4, 64, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 16, 64, 4, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 32, 128, 2, 32, 128, 128, 128, 1, true); + GGML_CUDA_FATTN_MMA_CONFIG_CASE(256, 256, 64, 256, 2, 32, 128, 128, 128, 1, true); + + // Fallback for unsupported DKQ values (e.g. 576). Must return non-zero values to satisfy + // compile-time static_asserts even though the kernel guard prevents runtime execution. + // nthreads=256 gives nwarps=4 (warp_size=64) or 8 (warp_size=32), nbatch_fa=128 satisfies np*16 divisibility. + return fattn_mma_config(256, 1, 128, 4, 4, 4, 1, false); +} + static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, const int DV, const int ncols, const int cc) { if (ampere_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); @@ -118,6 +156,9 @@ static __host__ fattn_mma_config ggml_cuda_fattn_mma_get_config(const int DKQ, c if (turing_mma_available(cc)) { return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); } + if (amd_mfma_available(cc)) { + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); + } if (amd_wmma_available(cc)) { return ggml_cuda_fattn_mma_get_config_rdna(DKQ, DV, ncols); } @@ -130,6 +171,8 @@ static constexpr __device__ fattn_mma_config ggml_cuda_fattn_mma_get_config(cons return ggml_cuda_fattn_mma_get_config_ampere(DKQ, DV, ncols); #elif defined(TURING_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_turing(DKQ, DV, ncols); +#elif defined(AMD_MFMA_AVAILABLE) + return ggml_cuda_fattn_mma_get_config_cdna(DKQ, DV, ncols); #elif defined(VOLTA_MMA_AVAILABLE) return ggml_cuda_fattn_mma_get_config_volta(DKQ, DV, ncols); #elif defined(AMD_WMMA_AVAILABLE) @@ -205,15 +248,15 @@ static constexpr __device__ bool ggml_cuda_fattn_mma_get_Q_in_reg(const int DKQ, } static constexpr __device__ int get_cols_per_thread() { -#if defined(AMD_WMMA_AVAILABLE) - return 1; // RDNA has a single column. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + return 1; // AMD has a single column per thread. #else return 2; // This is specifically KQ columns, Volta only has a single VKQ column. -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } static __host__ int get_cols_per_warp(const int cc) { - if (turing_mma_available(cc) || amd_wmma_available(cc)) { + if (turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc)) { return 16; } else { // Volta @@ -241,6 +284,7 @@ static constexpr __device__ int ggml_cuda_fattn_mma_get_nstages(const int DKQ, c template static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int D2, const int stride_KV, const int i_sup) { + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); // K/V data is loaded with decreasing granularity for D for better memory bandwidth. // The minimum granularity with cp.async is 16 bytes, with synchronous data loading it's 4 bytes. if constexpr (use_cp_async) { @@ -252,10 +296,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( const unsigned int tile_KV_32 = ggml_cuda_cvta_generic_to_shared(tile_KV); auto load = [&] __device__ (auto n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : chunks_per_row - chunks_per_row % (2*stride_k); const int k0_stop = chunks_per_row - chunks_per_row % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -263,7 +307,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -271,7 +315,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); cp_async_cg_16(tile_KV_32 + i*(stride_tile*sizeof(half2)) + k*16, KV + i*stride_KV + k*h2_per_chunk); } @@ -287,10 +331,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( } else { // TODO use ggml_cuda_memcpy_1 auto load = [&] __device__ (const int n) { - const int stride_k = WARP_SIZE >> n; - const int k0_start = stride_k == WARP_SIZE ? 0 : D2 - D2 % (2*stride_k); + const int stride_k = warp_size >> n; + const int k0_start = stride_k == warp_size ? 0 : D2 - D2 % (2*stride_k); const int k0_stop = D2 - D2 % (1*stride_k); - const int stride_i = WARP_SIZE / stride_k; + const int stride_i = warp_size / stride_k; if (k0_start == k0_stop) { return; @@ -298,7 +342,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int i0 = 0; i0 < nbatch_fa; i0 += nwarps*stride_i) { - const int i = i0 + threadIdx.y*stride_i + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int i = i0 + threadIdx.y*stride_i + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (i0 + nwarps*stride_i > nbatch_fa && i >= nbatch_fa) { break; @@ -306,7 +350,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_tile( #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_KV[i*stride_tile + k] = !oob_check || i < i_sup ? KV[i*stride_KV + k] : make_half2(0.0f, 0.0f); } @@ -324,18 +368,19 @@ template= 32 ? nbatch_fa * sizeof(half) : 64; - constexpr int cols_per_warp = 8*WARP_SIZE/nbatch_fa; + constexpr int cols_per_warp = 8*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; const unsigned int tile_mask_32 = ggml_cuda_cvta_generic_to_shared(tile_mask); #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { @@ -357,25 +402,25 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += warp_size) { const int i = i0 + threadIdx.x; tile_mask[j_sram*(nbatch_fa + 8) + i] = i < i_sup ? mask_h[j_vram*stride_mask + i] : half(0.0f); } } - } else if constexpr (nbatch_fa < 2*WARP_SIZE) { - constexpr int cols_per_warp = 2*WARP_SIZE/nbatch_fa; + } else if constexpr (nbatch_fa < 2*warp_size) { + constexpr int cols_per_warp = 2*warp_size/nbatch_fa; constexpr int stride_j = nwarps * cols_per_warp; #pragma unroll for (int j1 = 0; j1 < ncols1; j1 += stride_j) { - const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (WARP_SIZE/cols_per_warp); + const int j_sram = j1 + threadIdx.y*cols_per_warp + threadIdx.x / (warp_size/cols_per_warp); const int j_vram = fastmodulo(j0 + j_sram, ne01); if (j1 + stride_j > ncols1 && j_sram >= ncols1) { break; } - const int i = threadIdx.x % (WARP_SIZE/cols_per_warp); + const int i = threadIdx.x % (warp_size/cols_per_warp); ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + 2*i, mask_h + j_vram*stride_mask + 2*i); } @@ -390,7 +435,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_load_mask( } #pragma unroll - for (int i0 = 0; i0 < nbatch_fa; i0 += 2*WARP_SIZE) { + for (int i0 = 0; i0 < nbatch_fa; i0 += 2*warp_size) { const int i = i0 + 2*threadIdx.x; ggml_cuda_memcpy_1(tile_mask + j_sram*(nbatch_fa + 8) + i, mask_h + j_vram*stride_mask + i); @@ -428,7 +473,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int jt, const int kb0, const int k_VKQ_sup) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int cols_per_warp = T_B_KQ::I; constexpr int cols_per_thread = get_cols_per_thread(); @@ -447,7 +493,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( const int k_VKQ_0 = kb0 * nbatch_fa; #if defined(TURING_MMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*(cols_per_warp == 8 ? T_C_KQ::I : T_C_KQ::J))]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; #else // Volta T_C_KQ KQ_C[nbatch_fa/(np*T_C_KQ::J)]; @@ -500,13 +546,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[k_KQ_0/T_A_KQ::J]); #else // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[k_KQ_0/T_A_KQ::J], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -526,13 +572,13 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); } else { // Wide version of KQ_C is column-major -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], K_A, Q_B[0]); #else // swap A and B for CUDA. mma(KQ_C[i_KQ_00/(np*T_A_KQ::I)], Q_B[0], K_A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -585,12 +631,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -601,7 +647,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = 16; offset >= 4; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -611,12 +657,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = l % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[k0/(np*T_C_KQ::I)].x[l] = expf(KQ_C[k0/(np*T_C_KQ::I)].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[k0/(np*T_C_KQ::I)].x[l]; } else { @@ -649,12 +695,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_max_new[KQ_idx] = fmaxf(KQ_max_new[KQ_idx], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); } } @@ -666,6 +712,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( // Values per KQ column are spread across 4 threads: constexpr int offset_first = 2; constexpr int offset_last = 1; +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: 4 threads per Q column (threadIdx.x % 16 == col, spaced by 16). + constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // Values per KQ column are spread across 2 threads: constexpr int offset_first = 16; @@ -677,7 +727,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #endif // defined(TURING_MMA_AVAILABLE) #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, WARP_SIZE)); + KQ_max_new[col] = fmaxf(KQ_max_new[col], __shfl_xor_sync(0xFFFFFFFF, KQ_max_new[col], offset, warp_size)); } } @@ -687,12 +737,12 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( #pragma unroll for (int l = 0; l < T_C_KQ::ne; ++l) { if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { -#if defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int KQ_idx = 0; #else // Turing + Volta: const int KQ_idx = (l/2) % 2; -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) KQ_C[(k0/(np*T_C_KQ::J))].x[l] = expf(KQ_C[(k0/(np*T_C_KQ::J))].x[l] - KQ_max_new[KQ_idx]); KQ_rowsum_add[KQ_idx] += KQ_C[(k0/(np*T_C_KQ::J))].x[l]; } else { @@ -739,7 +789,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2( KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll @@ -818,7 +868,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( } const half2 * tile_V_i = !V_is_K_view || i0_stop > 2*nbatch_K2 ? tile_V : tile_V + i0_start/2; -#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#if defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) constexpr int i0_stride = cols_per_warp == 8 ? T_C_VKQ::I : 2*T_C_VKQ::J; #pragma unroll for (int i_VKQ_0 = i0_start; i_VKQ_0 < i0_stop; i_VKQ_0 += i0_stride) { @@ -830,24 +880,38 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( T_A_VKQ A; // Transposed in SRAM but not in registers, gets transposed on load. #if defined(LDMATRIX_TRANS_AVAILABLE) load_ldmatrix_trans(A, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA A register layout: A_mat[i=lane%16][k=4*(lane/16)+reg]. + // Normal load gives A_mat[seq][dv] but we need A_mat[dv][seq] = V^T. + // Load with transposed addressing: 4 strided half loads. + { + const half2 * xs0 = tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2; + const half * xs0_h = (const half *) xs0; + const int stride_h = stride_tile_V * 2; // stride in half units + half * A_h = (half *) A.x; +#pragma unroll + for (int l = 0; l < 4; ++l) { + A_h[l] = xs0_h[(4*(threadIdx.x / 16) + l) * stride_h + threadIdx.x % 16]; + } + } #else // TODO: Try to transpose tile_V when loading gmem to smem. // Use mma to transpose T_A_VKQ for RDNA. T_A_VKQ A_trans; load_ldmatrix(A_trans, tile_V_i + 2*k0*stride_tile_V + (i_VKQ_0 - i0_start)/2, stride_tile_V); mma(A, A_trans, A_identity); -#endif // defined(TURING_MMA_AVAILABLE) +#endif // defined(LDMATRIX_TRANS_AVAILABLE) if constexpr (T_B_KQ::I == 8) { mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); } else { // Wide version of VKQ_C is column-major. -#if defined(AMD_WMMA_AVAILABLE) - // RDNA matrix C is column-major. +#if defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) + // AMD matrix C is column-major. mma(VKQ_C[i_VKQ_0/i0_stride], A, B[k00/(np*T_A_VKQ::J)]); #else // swap A and B for CUDA. mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::J)], A); -#endif // defined(AMD_WMMA_AVAILABLE) +#endif // defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) } } } @@ -866,7 +930,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( mma(VKQ_C[i_VKQ_0/i0_stride], B[k00/(np*T_A_VKQ::I)], A); } } -#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) +#endif // defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) if constexpr (nstages <= 1) { __syncthreads(); // Only needed if tile_K == tile_V. @@ -879,7 +943,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter( tile_Q, tile_K, tile_V, tile_mask, Q_B, VKQ_C, KQ_max, KQ_rowsum, kb0); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } #if defined(TURING_MMA_AVAILABLE) @@ -899,7 +963,7 @@ template<> struct mma_tile_sizes<8> { using T_B_VKQ = tile< 8, 8, half2>; // column-major using T_C_VKQ = tile<16, 4, half2>; // row-major }; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template struct mma_tile_sizes { using T_A_KQ = tile<16, 8, half2>; // row-major using T_B_KQ = tile<16, 8, half2>; // column-major @@ -944,9 +1008,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int zt_gqa, const int kb0_start, const int kb0_stop) { -#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) //In this kernel Q, K, V are matrices while i, j, k are matrix indices. + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; using T_A_KQ = typename mma_tile_sizes::T_A_KQ; using T_B_KQ = typename mma_tile_sizes::T_B_KQ; @@ -986,7 +1051,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( T_B_KQ Q_B[(Q_in_reg ? DKQ/(2*T_B_KQ::J) : 1)]; #if defined(TURING_MMA_AVAILABLE) T_C_VKQ VKQ_C[cols_per_warp == 8 ? DV/T_C_VKQ::I : DV/(2*T_C_VKQ::J)]; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; #else // Volta T_C_VKQ VKQ_C[ DV/(2*T_C_VKQ::J)]; @@ -1004,10 +1069,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The loading is done with decreasing granularity for D for better memory bandwidth. const half2 scale_h2 = make_half2(scale, scale); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : DKQ/2 - (DKQ/2) % (2*stride_k); const int k0_stop = DKQ/2 - (DKQ/2) % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1015,7 +1080,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0 = 0; jc0 < ncols; jc0 += nwarps*stride_jc) { - const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc = jc0 + threadIdx.y*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0 + nwarps*stride_jc > ncols && jc >= ncols) { break; @@ -1027,7 +1092,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( if ((ncols1 == 1 || jt*ncols1 + j < int(ne01.z)) && (ncols2 == 1 || zt_gqa*ncols2 + c < gqa_ratio)) { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); const float2 tmp = Q_f2[(jt*ncols1 + j)*stride_Q1 + c*stride_Q2 + k]; tile_Q[jc*stride_tile_Q + k] = scale_h2 * make_half2(tmp.x, tmp.y); @@ -1035,7 +1100,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } else { #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); tile_Q[jc*stride_tile_Q + k] = make_half2(0.0f, 0.0f); } @@ -1127,6 +1192,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // The partial sums are spread across 8/4 threads. constexpr int offset_first = cols_per_warp == 8 ? 16 : 2; constexpr int offset_last = cols_per_warp == 8 ? 4 : 1; +#elif defined(AMD_MFMA_AVAILABLE) + // The partial sums are spread across 4 threads (wavefront64, 16 cols). + constexpr int offset_first = 32; + constexpr int offset_last = 16; #elif defined(AMD_WMMA_AVAILABLE) // The partial sums are spread across 2 threads. constexpr int offset_first = 16; @@ -1140,7 +1209,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( for (int col = 0; col < cols_per_thread; ++col) { #pragma unroll for (int offset = offset_first; offset >= offset_last; offset >>= 1) { - KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, WARP_SIZE); + KQ_rowsum[col] += __shfl_xor_sync(0xFFFFFFFF, KQ_rowsum[col], offset, warp_size); } } } @@ -1189,7 +1258,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } } } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale[0], KQ_max_scale[0]); #pragma unroll for (int i = 0; i < (DV/2)/T_C_VKQ::J; ++i) { @@ -1249,7 +1318,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(threadIdx.x % 4); const float2 KQ_cmr = make_float2(KQ_max[threadIdx.x % cols_per_thread], KQ_rowsum[threadIdx.x % cols_per_thread]); const bool thread_should_write = threadIdx.x % 4 < cols_per_thread; -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) const int jc_cwm = threadIdx.y*cols_per_warp + T_C_VKQ::get_i(0); const float2 KQ_cmr = make_float2(KQ_max[0], KQ_rowsum[0]); const bool thread_should_write = threadIdx.x / 16 < cols_per_thread; @@ -1283,14 +1352,14 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Warps with threadIdx.y % np != 0 must NOT return early. // All threads must return simultaneously to avoid race conditions with work on the next tile. - constexpr int nmeta = np*cols_per_warp >= WARP_SIZE ? np*cols_per_warp/WARP_SIZE : 1; + constexpr int nmeta = np*cols_per_warp >= warp_size ? np*cols_per_warp/warp_size : 1; - const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < WARP_SIZE ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); + const int jc_meta = threadIdx.y*cols_per_warp + (np*cols_per_warp < warp_size ? threadIdx.x % (np*cols_per_warp) : threadIdx.x); float2 * const meta_ptr = ((float2 *) tile_Q) + jc_meta*(tile_stride/2) + nbatch_combine/2; float2 meta[nmeta]; #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - meta[imeta] = meta_ptr[imeta * WARP_SIZE * tile_stride/2]; + meta[imeta] = meta_ptr[imeta * warp_size * tile_stride/2]; } float KQ_cmn = meta[0].x; // KQ combine max new, max between all parallel warps. @@ -1300,8 +1369,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, WARP_SIZE)); + if (offset < warp_size) { + KQ_cmn = fmaxf(KQ_cmn, __shfl_xor_sync(0xFFFFFFFF, KQ_cmn, offset, warp_size)); } } @@ -1318,8 +1387,8 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( } #pragma unroll for (int offset = np*cols_per_warp/2; offset >= cols_per_warp; offset >>= 1) { - if (offset < WARP_SIZE) { - KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, WARP_SIZE); + if (offset < warp_size) { + KQ_crs += __shfl_xor_sync(0xFFFFFFFF, KQ_crs, offset, warp_size); } } @@ -1328,19 +1397,19 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( // Write back combined meta data: #pragma unroll for (int imeta = 0; imeta < nmeta; ++imeta) { - if (np*cols_per_warp >= WARP_SIZE || threadIdx.x < np*cols_per_warp) { + if (np*cols_per_warp >= warp_size || threadIdx.x < np*cols_per_warp) { // Combined KQ max scale + rowsum. - meta_ptr[imeta * WARP_SIZE * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); + meta_ptr[imeta * warp_size * tile_stride/2] = make_float2(KQ_cms[imeta], KQ_crs); } } // Combined KQ max + rowsum. - static_assert(cols_per_warp <= WARP_SIZE); - if (needs_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + static_assert(cols_per_warp <= warp_size); + if (needs_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + blockIdx.x*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } - if (is_fixup && (cols_per_warp == WARP_SIZE || threadIdx.x < cols_per_warp)) { + if (is_fixup && (cols_per_warp == warp_size || threadIdx.x < cols_per_warp)) { float2 * dstk_fixup_meta = dstk_fixup + (gridDim.x + blockIdx.x)*ncols; dstk_fixup_meta[(threadIdx.y/np)*cols_per_warp + threadIdx.x] = make_float2(KQ_cmn, KQ_crs); } @@ -1388,10 +1457,10 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( float2 * dstk_fixup_data = dstk_fixup + gridDim.x*(2*ncols) + blockIdx.x*(ncols*(DV/2)); #pragma unroll - for (int stride_k : {WARP_SIZE, WARP_SIZE/2, WARP_SIZE/4}) { - const int k0_start = stride_k == WARP_SIZE ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); + for (int stride_k : {warp_size, warp_size/2, warp_size/4, warp_size/8}) { + const int k0_start = stride_k == warp_size ? 0 : nbatch_combine - nbatch_combine % (2*stride_k); const int k0_stop = nbatch_combine - nbatch_combine % (1*stride_k); - const int stride_jc = WARP_SIZE / stride_k; + const int stride_jc = warp_size / stride_k; if (k0_start == k0_stop) { continue; @@ -1399,7 +1468,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( #pragma unroll for (int jc0_dst = 0; jc0_dst < ncols; jc0_dst += (nwarps/np)*stride_jc) { - const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == WARP_SIZE ? 0 : threadIdx.x / stride_k); + const int jc_dst = jc0_dst + (threadIdx.y/np)*stride_jc + (stride_k == warp_size ? 0 : threadIdx.x / stride_k); if (jc0_dst + (nwarps/np)*stride_jc > ncols && jc_dst >= ncols) { break; @@ -1417,7 +1486,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( const float * meta_j = (const float *) tile_Q + jc_tile_K*tile_stride + nbatch_combine; #pragma unroll for (int k0 = k0_start; k0 < k0_stop; k0 += stride_k) { - const int k = k0 + (stride_k == WARP_SIZE ? threadIdx.x : threadIdx.x % stride_k); + const int k = k0 + (stride_k == warp_size ? threadIdx.x : threadIdx.x % stride_k); float2 dstk_val = make_float2(0.0f, 0.0f); #pragma unroll @@ -1453,7 +1522,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_process_tile( stride_Q1, stride_Q2, stride_K, stride_V, stride_mask, jt, kb0_start, kb0_stop); NO_DEVICE_CODE; -#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) +#endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE) } template @@ -1480,7 +1549,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#if defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(DKQ == 128 || DKQ == 256)) { @@ -1508,10 +1577,18 @@ static __global__ void flash_attn_ext_f16( } #endif // defined(AMD_WMMA_AVAILABLE) +#if defined(AMD_MFMA_AVAILABLE) + if (DKQ != 64 && DKQ != 80 && DKQ != 96 && DKQ != 112 && DKQ != 128) { + NO_DEVICE_CODE; + return; + } +#endif // defined(AMD_MFMA_AVAILABLE) + + constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int ncols = ncols1 * ncols2; constexpr int nbatch_fa = ggml_cuda_fattn_mma_get_nbatch_fa(DKQ, DV, ncols); constexpr int nthreads = ggml_cuda_fattn_mma_get_nthreads(DKQ, DV, ncols); - constexpr int nwarps = nthreads / WARP_SIZE; + constexpr int nwarps = nthreads / warp_size; const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. @@ -1624,7 +1701,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || (defined(AMD_WMMA_AVAILABLE) && defined(RDNA4)) || defined(AMD_MFMA_AVAILABLE)) } template @@ -1644,7 +1721,8 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml const int nstages = ggml_cuda_fattn_mma_get_nstages (DKQ, DV, ncols1, ncols2, cc); const int cols_per_warp = std::min(ncols, get_cols_per_warp(cc)); - const int nwarps = nthreads / WARP_SIZE; + const int warp_size_host = ggml_cuda_info().devices[ctx.device].warp_size; + const int nwarps = nthreads / warp_size_host; constexpr bool V_is_K_view = DKQ == 576; // Guaranteed by the kernel selection logic in fattn.cu @@ -1694,7 +1772,7 @@ void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ctx, ggml } launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true); + (ctx, dst, fattn_kernel, nwarps, nbytes_shared_total, nbatch_fa, true, true, true, warp_size_host); } diff --git a/ggml/src/ggml-cuda/fattn.cu b/ggml/src/ggml-cuda/fattn.cu index 721edd9994..85c177f496 100644 --- a/ggml/src/ggml-cuda/fattn.cu +++ b/ggml/src/ggml-cuda/fattn.cu @@ -440,6 +440,18 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_MMA_F16; } + // Use MFMA flash attention for CDNA (MI100+): + if (amd_mfma_available(cc) && Q->ne[0] != 40 && Q->ne[0] != 72 && Q->ne[0] != 256 && Q->ne[0] != 576) { + const int64_t eff_nq = Q->ne[1] * (gqa_opt_applies ? gqa_ratio : 1); + // MMA vs tile crossover benchmarked on MI300X @ d32768: + // hsk=64 (gqa=4): MMA wins at eff >= 128 (+11%) + // hsk=128 (gqa=4): MMA wins at eff >= 128 (+4%) + if (eff_nq >= (GGML_CUDA_CC_IS_CDNA1(cc) && Q->ne[0] == 64 ? 64 : 128)) { + return BEST_FATTN_KERNEL_MMA_F16; + } + // Fall through to tile kernel for small effective batch sizes. + } + // If there are no tensor cores available, use the generic tile kernel: if (can_use_vector_kernel) { if (!ggml_is_quantized(K->type) && !ggml_is_quantized(V->type)) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index ffa35eeb65..7e6d330354 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2979,10 +2979,6 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx const void * graph_key = ggml_cuda_graph_get_key(cgraph); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); - if (graph->instance == nullptr) { - res = true; - } - // Check if the graph size has changed if (graph->props.size() != (size_t)cgraph->n_nodes) { res = true; @@ -3931,14 +3927,35 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, #ifdef USE_CUDA_GRAPH graph_key = ggml_cuda_graph_get_key(cgraph); - use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); + ggml_cuda_graph_set_enabled(cuda_ctx, graph_key); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); if (graph->is_enabled()) { - cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); - use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); + const bool graph_compatible = ggml_cuda_graph_check_compability(cgraph); + if (graph_compatible) { + const bool properties_changed = ggml_cuda_graph_update_required(cuda_ctx, cgraph); - graph->record_update(use_cuda_graph, cuda_graph_update_required); + if (!graph->warmup_complete) { + // Warmup: need at least 2 calls with no property change on the 2nd call + if (!properties_changed) { + graph->warmup_complete = true; + GGML_LOG_DEBUG("%s: CUDA graph warmup complete\n", __func__); + use_cuda_graph = true; + cuda_graph_update_required = true; + } + // else: properties changed or first call - execute directly (use_cuda_graph stays false) + } else { + // Post-warmup: normal CUDA graph operation + if (properties_changed) { + // Properties changed - reset warmup, execute directly until stable again + graph->warmup_complete = false; + GGML_LOG_DEBUG("%s: CUDA graph warmup reset\n", __func__); + } else { + use_cuda_graph = true; + cuda_graph_update_required = graph->instance == nullptr; + } + } + } } #endif // USE_CUDA_GRAPH diff --git a/ggml/src/ggml-cuda/mma.cuh b/ggml/src/ggml-cuda/mma.cuh index dd45d6c78f..5d1dadd3e4 100644 --- a/ggml/src/ggml-cuda/mma.cuh +++ b/ggml/src/ggml-cuda/mma.cuh @@ -668,7 +668,7 @@ namespace ggml_cuda_mma { return ret; } -#elif defined(AMD_WMMA_AVAILABLE) +#elif defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE) template static __device__ __forceinline__ tile get_half2(const tile & tile_float) { tile ret; @@ -964,6 +964,34 @@ namespace ggml_cuda_mma { GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; #endif // defined(RDNA4) +#elif defined(AMD_MFMA_AVAILABLE) + // MFMA: FP16 input, FP32 accumulate, convert back to half2. + using halfx4_t = __attribute__((ext_vector_type(4))) _Float16; + using floatx4_t = __attribute__((ext_vector_type(4))) float; + + // Convert existing half2 accumulator to float for MFMA: + floatx4_t acc_f32; + { + const halfx4_t acc_h = reinterpret_cast(D.x[0]); +#pragma unroll + for (int i = 0; i < 4; ++i) { + acc_f32[i] = (float)acc_h[i]; + } + } + + const halfx4_t& a_frag = reinterpret_cast(A.x[0]); + const halfx4_t& b_frag = reinterpret_cast(B.x[0]); + acc_f32 = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_f32, 0, 0, 0); + + // Convert back to half2: + { + halfx4_t result_h; +#pragma unroll + for (int i = 0; i < 4; ++i) { + result_h[i] = (_Float16)acc_f32[i]; + } + reinterpret_cast(D.x[0]) = result_h; + } #else GGML_UNUSED_VARS(D, A, B); NO_DEVICE_CODE; diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 54f9986498..7a44443a8a 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1749,23 +1749,6 @@ static inline bool ggml_backend_buffer_is_hexagon_repack(const struct ggml_backe return b->buft->iface.alloc_buffer == ggml_backend_hexagon_repack_buffer_type_alloc_buffer; } -static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_tensor * y) { - if (x->ne[0] != y->ne[0]) { - return false; - } - if (x->ne[1] != y->ne[1]) { - return false; - } - if (x->ne[2] != y->ne[2]) { - return false; - } - if (x->ne[3] != y->ne[3]) { - return false; - } - - return true; -} - static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * src1 = op->src[1]; @@ -1797,43 +1780,6 @@ static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_sess return opt_experimental; } -static bool hex_supported_src0_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src1_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src2_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_src1_type2(ggml_type t) { - return t == GGML_TYPE_F16; -} - -static bool hex_supported_src1_type3(ggml_type t) { - return t == GGML_TYPE_I32; -} - -static bool hex_supported_dst_type(ggml_type t) { - return t == GGML_TYPE_F32; -} - -static bool hex_supported_dims(const struct ggml_tensor * x, const struct ggml_tensor * y) { - // TODO: support broadcast for ne[2 and 3] - if (x->ne[0] != y->ne[0]) { - return false; - } - if (x->ne[2] != y->ne[2]) { - return false; - } - if (x->ne[3] != y->ne[3]) { - return false; - } - return true; -} static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * sess, const struct ggml_tensor * dst) { const struct ggml_tensor * src0 = dst->src[0]; @@ -1919,19 +1865,19 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } - if (!ggml_can_repeat(src1, src0)) { + if (!ggml_can_repeat(src1, src0) || ggml_is_permuted(src1)) { return false; } @@ -1943,16 +1889,16 @@ static bool ggml_hexagon_supported_add_id(const struct ggml_hexagon_session * se const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -1968,13 +1914,13 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, dst)) { + if (!ggml_are_same_shape(src0, dst)) { return false; } @@ -1990,10 +1936,10 @@ static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * const struct ggml_tensor * src0 = op->src[0]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2011,10 +1957,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session const struct ggml_tensor * src1 = op->src[1]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } @@ -2023,10 +1969,10 @@ static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session } if (src1) { - if (!hex_supported_src1_type(src1->type)) { + if (src1->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dims2(src0, src1)) { + if (!ggml_are_same_shape(src0, src1)) { return false; } if (!ggml_is_contiguous(src1)) { @@ -2047,15 +1993,15 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return false; // FIXME: add support for sinks } - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } if (src1) { - if (!hex_supported_src1_type(src1->type) && !hex_supported_src1_type2(src1->type)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } if (src0->ne[0] != src1->ne[0]) { @@ -2162,17 +2108,17 @@ static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess const struct ggml_tensor * src2 = op->src[2]; const struct ggml_tensor * dst = op; - if (!hex_supported_src0_type(src0->type)) { + if (src0->type != GGML_TYPE_F32) { return false; // FIXME: add support for GGML_TYPE_F16 for src0 } - if (!hex_supported_dst_type(dst->type)) { + if (dst->type != GGML_TYPE_F32) { return false; } - if (!hex_supported_src1_type3(src1->type)) { + if (src1->type != GGML_TYPE_I32) { return false; } if (src2) { - if (!hex_supported_src2_type(src2->type)) { + if (src2->type != GGML_TYPE_F32) { return false; } int n_dims = op_params[1]; diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 950d836ad3..21bd4050a1 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -69,27 +69,45 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +struct htp_act_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + const uint8_t * data_src1; + uint8_t * data_dst; + + size_t src0_row_size; + size_t src1_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t src1_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t src1_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + int nc; +}; + +static void glu_swiglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; - - - - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -101,43 +119,34 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const int nc = actx->nc; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; - - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -196,27 +205,22 @@ static void glu_swiglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_swiglu_oai_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -226,45 +230,36 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const int nc = actx->nc; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; - - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " "%zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } - const float alpha = ((const float *) (op_params))[2]; - const float limit = ((const float *) (op_params))[3]; + const float alpha = ((const float *) (actx->octx->op_params))[2]; + const float limit = ((const float *) (actx->octx->op_params))[3]; + + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { @@ -335,26 +330,22 @@ static void glu_swiglu_oai_f32_per_thread(const struct htp_tensor * src0, } -static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_gelu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -364,25 +355,29 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + // nc/ne0 matches. + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; // In gelu = x*sigmoid(x*1.702) - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "gelu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -408,9 +403,9 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // gelu = x * sigmoid(1.702 * x) // current implementation - hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0); - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0_val); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -435,34 +430,23 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_gelu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_gelu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - - -static void unary_silu_f32_per_thread(const struct htp_tensor * src0, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void unary_silu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble2; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t src0_row_size = actx->src0_row_size; + const size_t dst_row_size = actx->dst_row_size; + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const uint32_t src0_nrows = ne01 * ne02 * ne03; + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -472,24 +456,27 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * data_src0 = (const uint8_t *) src0->data; - uint8_t * data_dst = (uint8_t *) dst->data; + const uint8_t * data_src0 = actx->data_src0; + uint8_t * data_dst = actx->data_dst; - uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + const int ne0_val = actx->nc; // == dst->ne[0] - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + uint8_t * src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; + + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -515,8 +502,8 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); // silu = x * sigmoid(x) - hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0); - hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0); + hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0_val); + hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0_val); } dma_queue_push_vtcm_to_ddr(dma_queue, @@ -544,27 +531,22 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0, static const float GELU_COEF_A = 0.044715f; static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; -static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void glu_geglu_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_act_context * actx = (struct htp_act_context *) data; + const struct htp_tensor * src0 = &actx->octx->src0; + const struct htp_tensor * src1 = &actx->octx->src1; + const struct htp_tensor * dst = &actx->octx->dst; htp_act_preamble3; - size_t src0_row_size = nb01; - size_t src1_row_size = nb11; - size_t dst_row_size = nb1; + size_t src0_row_size = actx->src0_row_size; + size_t src1_row_size = actx->src1_row_size; + size_t dst_row_size = actx->dst_row_size; uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t src0_nrows = actx->src0_nrows; + const uint32_t src0_nrows_per_thread = actx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -574,43 +556,34 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const uint8_t * restrict data_src0 = actx->data_src0; + const uint8_t * restrict data_src1 = actx->data_src1; + uint8_t * restrict data_dst = actx->data_dst; - const bool src1_valid = src1->ne[0]; - const int nc = (src1_valid) ? ne00 : ne00 / 2; - if (!src1_valid) { - const int32_t swapped = op_params[1]; - data_src1 = data_src0; - src1_row_size = src0_row_size; + const int nc = actx->nc; - const size_t nc_in_bytes = nc * SIZEOF_FP32; - data_src0 += swapped ? nc_in_bytes : 0; - data_src1 += swapped ? 0 : nc_in_bytes; - } + const size_t src0_row_size_aligned = actx->src0_row_size_aligned; + const size_t src1_row_size_aligned = actx->src1_row_size_aligned; + const size_t dst_row_size_aligned = actx->dst_row_size_aligned; - const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); - const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN); - const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + uint8_t * restrict src0_spad_data = actx->octx->src0_spad.data + (ith * actx->octx->src0_spad.size_per_thread); + uint8_t * restrict src1_spad_data = actx->octx->src1_spad.data + (ith * actx->octx->src1_spad.size_per_thread); + uint8_t * restrict dst_spad_data = actx->octx->dst_spad.data + (ith * actx->octx->dst_spad.size_per_thread); - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + size_t src0_spad_half_size = actx->src0_spad_half_size; + size_t src1_spad_half_size = actx->src1_spad_half_size; + size_t dst_spad_half_size = actx->dst_spad_half_size; - // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 - size_t src0_spad_half_size = src0_spad->size_per_thread / 2; - size_t src1_spad_half_size = src1_spad->size_per_thread / 2; - size_t dst_spad_half_size = dst_spad->size_per_thread / 2; - - const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + const int BLOCK = actx->block; if (BLOCK == 0) { FARF(ERROR, "geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", - src0_spad->size_per_thread, src0_row_size_aligned); + actx->octx->src0_spad.size_per_thread, src0_row_size_aligned); return; } + dma_queue * dma_queue = actx->octx->ctx->dma[ith]; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); @@ -678,33 +651,7 @@ static void glu_geglu_f32_per_thread(const struct htp_tensor * src0, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_silu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_swiglu_oai_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - -static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); -} - static int execute_op_activations_f32(struct htp_ops_context * octx) { - int err = HTP_STATUS_OK; - const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; @@ -719,26 +666,26 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { switch (octx->op) { case HTP_OP_UNARY_SILU: - act_op_func = unary_silu_f32; + act_op_func = (worker_callback_t)unary_silu_f32_per_thread; op_type = "silu-f32"; break; case HTP_OP_GLU_SWIGLU: - act_op_func = glu_swiglu_f32; + act_op_func = (worker_callback_t)glu_swiglu_f32_per_thread; op_type = "swiglu-f32"; break; case HTP_OP_GLU_SWIGLU_OAI: - act_op_func = glu_swiglu_oai_f32; + act_op_func = (worker_callback_t)glu_swiglu_oai_f32_per_thread; op_type = "swiglu-oai-f32"; break; case HTP_OP_UNARY_GELU: - act_op_func = unary_gelu_f32; + act_op_func = (worker_callback_t)unary_gelu_f32_per_thread; op_type = "gelu-f32"; break; case HTP_OP_GLU_GEGLU: - act_op_func = glu_geglu_f32; + act_op_func = (worker_callback_t)glu_geglu_f32_per_thread; op_type = "geglu-f32"; break; default: @@ -797,13 +744,58 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) { octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); } - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, act_op_func, octx, n_jobs); + if ((octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + return HTP_STATUS_OK; } - return err; + uint32_t n_jobs = MIN(n_threads, src0_nrows); + + // Prepare context + struct htp_act_context actx; + actx.octx = octx; + + actx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + + actx.src0_row_size = src0_row_size; + actx.src1_row_size = src1_row_size; + actx.dst_row_size = dst_row_size; + + actx.src0_row_size_aligned = src0_row_size_aligned; + actx.src1_row_size_aligned = src1_row_size_aligned; + actx.dst_row_size_aligned = dst_row_size_aligned; + + actx.src0_spad_half_size = octx->src0_spad.size_per_thread / 2; + actx.src1_spad_half_size = octx->src1_spad.size_per_thread / 2; + actx.dst_spad_half_size = octx->dst_spad.size_per_thread / 2; + + actx.block = actx.src0_spad_half_size / actx.src0_row_size_aligned; + actx.src0_nrows = src0_nrows; + + actx.nc = dst->ne[0]; + + // Pointers and GLU logic + const uint8_t * data_src0 = (const uint8_t *) src0->data; + const uint8_t * data_src1 = (const uint8_t *) src1->data; + + if (!src1_valid && (octx->op == HTP_OP_GLU_SWIGLU || octx->op == HTP_OP_GLU_SWIGLU_OAI || octx->op == HTP_OP_GLU_GEGLU)) { + const int32_t swapped = octx->op_params[1]; + data_src1 = data_src0; + actx.src1_row_size = actx.src0_row_size; + + size_t nc_in_bytes = actx.nc * SIZEOF_FP32; + if (swapped) { + data_src0 += nc_in_bytes; + } else { + data_src1 += nc_in_bytes; + } + } + + actx.data_src0 = data_src0; + actx.data_src1 = data_src1; + actx.data_dst = (uint8_t *) dst->data; + + worker_pool_run_func(octx->ctx->worker_pool, act_op_func, &actx, n_jobs); + return HTP_STATUS_OK; } int op_activations(struct htp_ops_context * octx) { diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c index a657cd2dcf..bf24bbda70 100644 --- a/ggml/src/ggml-hexagon/htp/get-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -15,6 +15,13 @@ #include "htp-ops.h" #include "hvx-utils.h" +struct get_rows_context { + struct htp_ops_context * octx; + uint32_t src1_nrows_per_thread; + struct fastdiv_values get_rows_div_ne10; + struct fastdiv_values get_rows_div_ne10_ne11; +}; + #define get_rows_preamble \ const uint32_t ne00 = octx->src0.ne[0]; \ const uint32_t ne01 = octx->src0.ne[1]; \ @@ -39,20 +46,22 @@ \ const uint32_t nr = ne10 * ne11 * ne12; -static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void get_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct get_rows_context * grctx = (struct get_rows_context *)data; + struct htp_ops_context * octx = grctx->octx; get_rows_preamble; // parallelize by src1 elements (which correspond to dst rows) - const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t dr = grctx->src1_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t i12 = fastdiv(i, &grctx->get_rows_div_ne10_ne11); const uint32_t rem = i - i12 * ne11 * ne10; - const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i11 = fastdiv(rem, &grctx->get_rows_div_ne10); const uint32_t i10 = rem - i11 * ne10; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -68,12 +77,6 @@ static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; hvx_copy_f32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); } - - return HTP_STATUS_OK; -} - -static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_get_rows(struct htp_ops_context * octx) { @@ -95,12 +98,14 @@ int op_get_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); - octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + struct get_rows_context grctx; + grctx.octx = octx; + grctx.get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + grctx.get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + grctx.src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, get_rows_thread_f32_f32, &grctx, n_jobs); return HTP_STATUS_OK; } diff --git a/ggml/src/ggml-hexagon/htp/hex-dma.h b/ggml/src/ggml-hexagon/htp/hex-dma.h index d1ddb0ecbf..350ab9d966 100644 --- a/ggml/src/ggml-hexagon/htp/hex-dma.h +++ b/ggml/src/ggml-hexagon/htp/hex-dma.h @@ -102,7 +102,7 @@ static inline bool dma_queue_push(dma_queue * q, dmlink(q->tail, desc); q->tail = desc; - // FARF(ERROR, "dma-push: i %u len %u dst %p src %p\n", q->push_idx, len, dst, src); + // FARF(ERROR, "dma-push: i %u width %u nrows %d dst %p src %p\n", q->push_idx, width, nrows, dptr.dst, dptr.src); q->push_idx = (q->push_idx + 1) & q->idx_mask; return true; } @@ -144,11 +144,37 @@ static inline dma_ptr dma_queue_pop(dma_queue * q) { dptr = q->dptr[q->pop_idx]; - // FARF(ERROR, "dma-pop: i %u dst %p\n", q->pop_idx, dst); + // FARF(ERROR, "dma-pop: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); q->pop_idx = (q->pop_idx + 1) & q->idx_mask; return dptr; } +static inline dma_ptr dma_queue_pop_nowait(dma_queue * q) { + dma_ptr dptr = { NULL }; + + if (q->push_idx == q->pop_idx) { + return dptr; + } + + dptr = q->dptr[q->pop_idx]; + + // FARF(ERROR, "dma-pop-nowait: i %u dst %p src %p\n", q->pop_idx, dptr.dst, dptr.src); + q->pop_idx = (q->pop_idx + 1) & q->idx_mask; + return dptr; +} + +static inline bool dma_queue_empty(dma_queue * q) { + return q->push_idx == q->pop_idx; +} + +static inline uint32_t dma_queue_depth(dma_queue * q) { + return (q->push_idx - q->pop_idx) & q->idx_mask; +} + +static inline uint32_t dma_queue_capacity(dma_queue * q) { + return q->capacity; +} + #ifdef __cplusplus } // extern "C" #endif diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index f1ad24dbfa..127ab1d665 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -44,32 +44,6 @@ struct htp_ops_context { uint32_t src0_nrows_per_thread; uint32_t src1_nrows_per_thread; - struct fastdiv_values src0_div1; // fastdiv values for ne1 - struct fastdiv_values src0_div2; // fastdiv values for ne2 - struct fastdiv_values src0_div3; // fastdiv values for ne3 - struct fastdiv_values src0_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src1_div1; // fastdiv values for ne1 - struct fastdiv_values src1_div2; // fastdiv values for ne2 - struct fastdiv_values src1_div3; // fastdiv values for ne3 - struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values src3_div1; // fastdiv values for ne1 - struct fastdiv_values src3_div2; // fastdiv values for ne2 - struct fastdiv_values src3_div3; // fastdiv values for ne3 - struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 - - struct fastdiv_values broadcast_rk2; - struct fastdiv_values broadcast_rk3; - struct fastdiv_values broadcast_rv2; - struct fastdiv_values broadcast_rv3; - - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 - struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 - - struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 - struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index c360abe8da..6f6f51f01f 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -49,62 +49,6 @@ struct htp_matmul_context { struct fastdiv_values mm_div_r3; }; -// vdelta control to replicate first 4x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_4x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, - 0x04, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, -}; - -// vdelta control to replicate and interleave first 8x fp32 values across lanes -static const uint8_t __attribute__((aligned(128))) repl_interleave_8x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x00, 0x00, 0x00, - 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, - 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, - 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, - 0x44, 0x44, 0x44, 0x44, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x44, 0x44, 0x44, - 0x44, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, -}; - -// vdelta control to replicate first fp32 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_f32[128] = { - 0x00, 0x00, 0x00, 0x00, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, - 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, - 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, - 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x40, 0x40, 0x40, 0x40, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, - 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, - 0x04, 0x20, 0x20, 0x20, 0x20, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, 0x10, 0x10, - 0x10, 0x10, 0x04, 0x04, 0x04, 0x04, 0x08, 0x08, 0x08, 0x08, 0x04, 0x04, 0x04, 0x04, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_1x_f16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, - 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, - 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, - 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x40, 0x40, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, - 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, - 0x02, 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x10, 0x10, - 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, -}; - -// vdelta control to replicate first fp16 value across all elements -static const uint8_t __attribute__((aligned(128))) repl_2x_f16[128] = { - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x00, 0x00, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x20, 0x20, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, - 0x10, 0x10, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, 0x08, 0x08, 0x02, 0x02, 0x04, 0x04, 0x02, 0x02, -}; - // vdelta control to expand first 32 e8m0 values into 32 uint32 elements static const uint8_t __attribute__((aligned(128))) expand_x32_e8m0[128] = { 0x00, 0x00, 0x00, 0x00, 0x01, 0x04, 0x00, 0x00, 0x02, 0x00, 0x08, 0x08, 0x01, 0x02, 0x00, 0x04, 0x04, 0x00, 0x00, @@ -2067,10 +2011,10 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vx3_qf = Q6_Vqf32_vsub_VsfVsf(vx[3], zero); // 32 elements // Convert to QF32 - HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); - HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); - HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); - HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); + HVX_Vector vmax0_qf = Q6_Vqf32_vsub_VsfVsf(vmax0_sf, zero); // replicated over all lanes + HVX_Vector vmax1_qf = Q6_Vqf32_vsub_VsfVsf(vmax1_sf, zero); // replicated over all lanes + HVX_Vector vmax2_qf = Q6_Vqf32_vsub_VsfVsf(vmax2_sf, zero); // replicated over all lanes + HVX_Vector vmax3_qf = Q6_Vqf32_vsub_VsfVsf(vmax3_sf, zero); // replicated over all lanes // Combine and convert to fp16 HVX_Vector vmax01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vmax1_qf, vmax0_qf))); @@ -2080,11 +2024,6 @@ static inline void quantize_block_f32_q8x1(float * restrict x, uint8_t * restric HVX_Vector vx01_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx1_qf, vx0_qf))); HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_2x_f16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); - HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd01_hf = Q6_Vhf_equals_Vqf16(vd01_qf16); @@ -2130,13 +2069,8 @@ static inline void quantize_block_f32_q8x2(float * restrict x, uint8_t * restric HVX_Vector vx23_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vx3_qf, vx2_qf))); // Compute max and scale - HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); - - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; - vmax01_hf = Q6_V_vdelta_VV(vmax01_hf, ctrl); - vmax23_hf = Q6_V_vdelta_VV(vmax23_hf, ctrl); + HVX_Vector vmax01_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); // replicated over all lanes + HVX_Vector vmax23_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx23_hf)); // replicated over all lanes HVX_Vector vd01_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax01_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd23_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax23_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 @@ -2179,11 +2113,7 @@ static inline void quantize_block_f32_q8x4(float * restrict x, uint8_t * restric // Compute max and scale HVX_Vector vmax_hf = hvx_vec_reduce_max_f16(hvx_vec_abs_f16(vx01_hf)); - vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); - - // Replicate first fp16 scale across all lanes - HVX_Vector ctrl = *(const HVX_Vector *) repl_1x_f16; - vmax_hf = Q6_V_vdelta_VV(vmax_hf, ctrl); + vmax_hf = hvx_vec_reduce_max2_f16(hvx_vec_abs_f16(vx23_hf), vmax_hf); // replicated over all lanes HVX_Vector vd_qf16 = Q6_Vqf16_vmpy_VhfVhf(vmax_hf, Q6_Vh_vsplat_R(0x2008)); // 1.0 / 127.0 HVX_Vector vd_hf = Q6_Vhf_equals_Vqf16(vd_qf16); diff --git a/ggml/src/ggml-hexagon/htp/rope-ops.c b/ggml/src/ggml-hexagon/htp/rope-ops.c index 943ca5c952..aa6a6c9008 100644 --- a/ggml/src/ggml-hexagon/htp/rope-ops.c +++ b/ggml/src/ggml-hexagon/htp/rope-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -21,6 +22,9 @@ #define HTP_ROPE_TYPE_NORMAL 0 #define HTP_ROPE_TYPE_NEOX 2 +#define HTP_ROPE_SPAD_NROWS 16 +#define HTP_ROPE_SPAD_BLOCK (HTP_ROPE_SPAD_NROWS/2) + #define htp_rope_preamble \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ @@ -42,7 +46,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct rope_th_ctx { +struct htp_rope_context { int32_t n_dims; int32_t mode; int32_t n_ctx_orig; @@ -57,7 +61,19 @@ struct rope_th_ctx { float theta_scale; float corr_dims[2]; + uint32_t src0_nrows_per_thread; + size_t spad_stride; + struct htp_ops_context * octx; + + size_t src0_row_size; + size_t dst_row_size; + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + size_t theta_cache_offset; + uint32_t src0_nrows; + + uint64_t t_start; }; static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -117,64 +133,23 @@ static void rope_corr_dims(int n_dims, dims[1] = MIN(n_dims - 1, end); } -static void init_rope_ctx(struct rope_th_ctx * rope_ctx, struct htp_ops_context * octx) { - memset(rope_ctx, 0, sizeof(struct rope_th_ctx)); +static inline void hvx_rope_neox_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; - const int32_t * op_params = &octx->op_params[0]; + uint32_t nvec = (ne / (VLEN_FP32 * 2) * 2); // 2 vecs per loop, step of 2 - rope_ctx->n_dims = ((const int32_t *) op_params)[1]; - rope_ctx->mode = ((const int32_t *) op_params)[2]; - rope_ctx->n_ctx_orig = ((const int32_t *) op_params)[4]; + uint32_t he = ne / 2; // half_dims offset in elements + uint32_t hv = he / VLEN_FP32; // half_dims offset in vectors - memcpy(&rope_ctx->freq_base, (int32_t *) op_params + 5, sizeof(float)); - memcpy(&rope_ctx->freq_scale, (int32_t *) op_params + 6, sizeof(float)); - memcpy(&rope_ctx->ext_factor, (int32_t *) op_params + 7, sizeof(float)); - memcpy(&rope_ctx->attn_factor, (int32_t *) op_params + 8, sizeof(float)); - memcpy(&rope_ctx->beta_fast, (int32_t *) op_params + 9, sizeof(float)); - memcpy(&rope_ctx->beta_slow, (int32_t *) op_params + 10, sizeof(float)); - memcpy(&rope_ctx->sections, (int32_t *) op_params + 11, sizeof(int) * 4); + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i += 2) { + HVX_Vector v0 = vsrc[i/2+0]; + HVX_Vector v1 = vsrc[i/2+hv]; - rope_ctx->theta_scale = powf(rope_ctx->freq_base, -2.0f / rope_ctx->n_dims); - - rope_corr_dims(rope_ctx->n_dims, rope_ctx->n_ctx_orig, rope_ctx->freq_base, rope_ctx->beta_fast, - rope_ctx->beta_slow, rope_ctx->corr_dims); - - rope_ctx->octx = octx; - FARF(HIGH, "rope-f32 n_dims:%d, ext_factor:%.6f, theta_scale:%.6f, attn_factor:%.6f\n", rope_ctx->n_dims, - rope_ctx->ext_factor, rope_ctx->theta_scale, rope_ctx->attn_factor); -} - -static void hvx_calc_rope_neox_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; - - //const float x0 = src[0]; - //const float x1 = src[num_elems/2]; - - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[num_elems/2] = x0*sin_theta + x1*cos_theta; - - //src += 1; - //dst += 1; - // } - - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - int half_size = (sizeof(float) * (num_elems / 2)); - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + half_size); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -186,45 +161,34 @@ static void hvx_calc_rope_neox_f32(const float * restrict src0, HVX_Vector v4 = Q6_Vqf32_vsub_Vqf32Vqf32(vx0_c, vx1_s); HVX_Vector v5 = Q6_Vqf32_vadd_Vqf32Vqf32(vx0_s, vx1_c); - *(HVX_Vector *) dst_curr = Q6_Vsf_equals_Vqf32(v4); - *(HVX_Vector *) (dst_curr + half_size) = Q6_Vsf_equals_Vqf32(v5); + vdst[i/2+0] = Q6_Vsf_equals_Vqf32(v4); + vdst[i/2+hv] = Q6_Vsf_equals_Vqf32(v5); + } - src0_curr += VLEN; - theta_curr += 2 * VLEN; - dst_curr += VLEN; + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i/2]; + float x1 = src0[i/2 + he]; + dst[i/2] = x0 * cos_theta - x1 * sin_theta; + dst[i/2 + he] = x0 * sin_theta + x1 * cos_theta; } } -static void hvx_calc_rope_f32(const float * restrict src0, - float * restrict dst, - const int num_elems, - const float * restrict theta_cache) { - // for (int i = 0; i < num_elems; i += 2) { - //const float cos_theta = theta_cache[i + 0]; - //const float sin_theta = theta_cache[i + 1]; +static inline void hvx_rope_f32_aa(float * restrict dst, const float * restrict src0, uint32_t ne, const float * restrict theta_cache) { + const HVX_Vector * restrict vsrc = (const HVX_Vector *) src0; + const HVX_Vector * restrict vtheta = (const HVX_Vector *) theta_cache; + HVX_Vector * restrict vdst = (HVX_Vector *) dst; - //const float x0 = src[0]; - //const float x1 = src[1]; + uint32_t nvec = (ne / (VLEN_FP32 * 2)) * 2; // 2 vecs per loop, step of two - //dst[0] = x0*cos_theta - x1*sin_theta; - //dst[1] = x0*sin_theta + x1*cos_theta; + #pragma unroll(2) + for (uint32_t i = 0; i < nvec; i+=2) { + HVX_Vector v0 = vsrc[i+0]; + HVX_Vector v1 = vsrc[i+1]; - //src += 2; - //dst += 2; - // } - - const uint8_t * restrict src0_curr = (const uint8_t *) src0; - const uint8_t * restrict theta_curr = (const uint8_t *) theta_cache; - uint8_t * restrict dst_curr = (uint8_t *) dst; - - int step_of_1 = num_elems >> 6; // 6 because we process two vectors at once - - for (int i = 0; i < step_of_1; i++) { - HVX_Vector v0 = *(HVX_Vector *) src0_curr; - HVX_Vector v1 = *(HVX_Vector *) (src0_curr + VLEN); - - HVX_Vector v2 = *(HVX_Vector *) theta_curr; - HVX_Vector v3 = *(HVX_Vector *) (theta_curr + VLEN); + HVX_Vector v2 = vtheta[i+0]; + HVX_Vector v3 = vtheta[i+1]; HVX_VectorPair vx0_x1 = Q6_W_vdeal_VVR(v1, v0, -4); // vx0_x1[0] = x0, vx0_x1[1] = x1 HVX_VectorPair vcos_sin = Q6_W_vdeal_VVR(v3, v2, -4); // vcos_sin[0] = cos_theta, vcos_sin[1] = sin_theta @@ -239,116 +203,65 @@ static void hvx_calc_rope_f32(const float * restrict src0, HVX_VectorPair vstore = Q6_W_vshuff_VVR(Q6_Vsf_equals_Vqf32(v5), Q6_Vsf_equals_Vqf32(v4), -4); - *(HVX_Vector *) dst_curr = Q6_V_lo_W(vstore); - *(HVX_Vector *) (dst_curr + VLEN) = Q6_V_hi_W(vstore); + vdst[i+0] = Q6_V_lo_W(vstore); + vdst[i+1] = Q6_V_hi_W(vstore); + } - src0_curr += 2 * VLEN; - theta_curr += 2 * VLEN; - dst_curr += 2 * VLEN; + for (uint32_t i = nvec * VLEN_FP32; i < ne; i += 2) { + const float cos_theta = theta_cache[i+0]; + const float sin_theta = theta_cache[i+1]; + float x0 = src0[i+0]; + float x1 = src0[i+1]; + dst[i+0] = x0 * cos_theta - x1 * sin_theta; + dst[i+1] = x0 * sin_theta + x1 * cos_theta; } } -static void rope_hex_f32(struct rope_th_ctx * rope_ctx, - const uint32_t ir0, - const uint32_t ir1, - int nth, - int ith, - const int opt_path) { - struct htp_ops_context * octx = rope_ctx->octx; +static void inline rope_basic_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); + + hvx_rope_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } + } +} + +static void inline rope_neox_f32(struct htp_rope_context * rctx, uint8_t * restrict dst, uint8_t * restrict src, + uint32_t nr, uint32_t ne0, const float * restrict theta_cache) { + #pragma unroll(4) + for (uint32_t i = 0; i < nr; i++) { + float * d = (float *) (dst + i * rctx->dst_row_size_aligned); + float * s = (float *) (src + i * rctx->src0_row_size_aligned); + + hvx_rope_neox_f32_aa(d, s, rctx->n_dims, theta_cache); + + // fill the remain channels with data from src tensor + if (rctx->n_dims < ne0) { + hvx_copy_f32_uu((uint8_t *)(d + rctx->n_dims), (uint8_t *)(s + rctx->n_dims), ne0 - rctx->n_dims); + } + } +} + +static void rope_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_rope_context * rctx = (struct htp_rope_context *) data; + struct htp_ops_context * octx = rctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - const int32_t mode = rope_ctx->mode; - const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; - htp_rope_preamble; - const int32_t * pos = (const int32_t *) src1->data; - - float * wp0 = (float *) (octx->src0_spad.data + (ith * nb01)); - - const float * freq_factors = NULL; - if (src2 != NULL) { - freq_factors = (const float *) src2->data; - } - - const uint32_t i1_end = MIN(ir1, ne1); - const int32_t half_dims = rope_ctx->n_dims / 2; - const size_t remain_bytes = (ne0 - rope_ctx->n_dims) * sizeof(float); - for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch - for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len - const int32_t p = pos[i2]; - - rope_cache_init(p, rope_ctx->freq_scale, freq_factors, rope_ctx->corr_dims, ne0, rope_ctx->ext_factor, - rope_ctx->attn_factor, wp0, rope_ctx->theta_scale); - - for (uint32_t i1 = ir0; i1 < i1_end; i1++) { // attn-heads - const float * src = (float *) ((char *) src0->data + i3 * nb03 + i2 * nb02 + i1 * nb01); - float * dst_data = (float *) ((char *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1); - - const float * src_loc = src; - float * dst_data_loc = dst_data; - - if (1 == opt_path) { - if (is_neox) { - hvx_calc_rope_neox_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } else { - hvx_calc_rope_f32(src_loc, dst_data_loc, rope_ctx->n_dims, wp0); - } - - src_loc += rope_ctx->n_dims; - dst_data_loc += rope_ctx->n_dims; - } else { - for (uint32_t i0 = 0; i0 < rope_ctx->n_dims; i0 += 2) { - const float cos_theta = wp0[i0 + 0]; - const float sin_theta = wp0[i0 + 1]; - - if (is_neox) { - const float x0 = src_loc[0]; - const float x1 = src_loc[half_dims]; - - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[half_dims] = x0 * sin_theta + x1 * cos_theta; - - src_loc += 1; - dst_data_loc += 1; - } else { - const float x0 = src_loc[0]; - const float x1 = src_loc[1]; - - dst_data_loc[0] = x0 * cos_theta - x1 * sin_theta; - dst_data_loc[1] = x0 * sin_theta + x1 * cos_theta; - - src_loc += 2; - dst_data_loc += 2; - } - } - - src_loc += (is_neox ? half_dims : 0); - dst_data_loc += (is_neox ? half_dims : 0); - } - - // TODO: use simd to speed up the remaining elements copy - memcpy(dst_data_loc, src_loc, remain_bytes); - } - } - } -} - -static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int ith) { - struct htp_ops_context * octx = rope_ctx->octx; - - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; - - htp_rope_preamble; - - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows = rctx->src0_nrows; + const uint32_t src0_nrows_per_thread = rctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -358,32 +271,114 @@ static void rope_job_f32_per_thread(struct rope_th_ctx * rope_ctx, int nth, int return; } - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); + uint64_t tt = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) || (0 == hex_is_aligned((void *) src1->data, VLEN)) || - (0 == hex_is_aligned((void *) dst->data, VLEN))) { - FARF(HIGH, "rope-f32: unaligned addresses in rope op, possibly slower execution\n"); - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const int32_t mode = rctx->mode; + const bool is_neox = mode & HTP_ROPE_TYPE_NEOX; + + // VTCM setup + uint8_t * src0_spad_base = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + float * theta_cache = (float *) (src0_spad_base); + src0_spad_base = src0_spad_base + rctx->theta_cache_offset; + uint8_t * dst_spad_base = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + dma_queue * dma_queue = octx->ctx->dma[ith]; + const int32_t * pos = (const int32_t *) src1->data; + const float * freq_factors = src2->data ? (const float *) src2->data : NULL; + + uint32_t ir = 0; + uint32_t prev_i2 = (uint32_t) -1; + + for (uint32_t i3 = 0; i3 < ne3; i3++) { // batch + for (uint32_t i2 = 0; i2 < ne2; i2++) { // seq-len + for (uint32_t i1 = 0; i1 < ne1; ) { // attn-heads + if (ir < src0_start_row) { ir++; i1++; continue; } + if (ir >= src0_end_row) goto done; + + // Rows in this block + const uint32_t nrows = MIN(src0_end_row - ir, ne1 - i1); + + // Depth before prefetch + uint32_t dma_depth = dma_queue_depth(dma_queue); + + // FARF(HIGH, "rope-block %u: ir %u n-rows %u dma-depth %u : usec %u", ith, ir, nrows, dma_depth, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + + // Prefetch loop + for (uint32_t pnr = 0, pr = 0; pr < nrows && pr < HTP_ROPE_SPAD_NROWS; pr += pnr) { + pnr = MIN(nrows - pr, HTP_ROPE_SPAD_BLOCK); + + uint32_t pi1 = i1 + pr; + uint32_t pir = ir + pr; + + // Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr((void *) dst->data, dst_spad_base + pr * rctx->dst_row_size_aligned), 0, 0, 0); + + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + uint8_t * src_spad = src0_spad_base + pr * rctx->src0_row_size_aligned; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + + // Update theta cache + if (i2 != prev_i2) { + prev_i2 = i2; + + const int32_t p = pos[i2]; + rope_cache_init(p, rctx->freq_scale, freq_factors, rctx->corr_dims, ne0, rctx->ext_factor, rctx->attn_factor, theta_cache, rctx->theta_scale); + + // FARF(HIGH, "rope-theta %u: ir %u i1 %u i2 %u i3 %u cache %p : usec %u", ith, ir, i1, i2, i3, theta_cache, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + } + + // Skip DMA transactions from prev block (if any) + // No need to wait for these since the DMA is setup for in-order processing + for (uint32_t d=0; d < dma_depth; d++) { dma_queue_pop_nowait(dma_queue); } + + // Compute loop + for (uint32_t cnr = 0, cr = 0; cr < nrows; cr += cnr, ir += cnr, i1 += cnr) { + // Number of rows to compute + cnr = MIN(nrows - cr, HTP_ROPE_SPAD_BLOCK); + + uint8_t * dst_spad = (uint8_t *) dma_queue_pop(dma_queue).src; + uint8_t * src_spad = (uint8_t *) dma_queue_pop(dma_queue).dst; + + // FARF(HIGH, "rope-compute %u: ir %u i1 %u i2 %u i3 %u src-spad %p cnr %u : usec %u", ith, ir, i1, i2, i3, src_spad, cnr, + // (unsigned) HAP_perf_qtimer_count_to_us(HAP_perf_get_qtimer_count() - rctx->t_start)); + + if (is_neox) { + rope_neox_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } else { + rope_basic_f32(rctx, dst_spad, src_spad, cnr, ne0, theta_cache); + } + + uint8_t * dst_addr = (uint8_t *) dst->data + i3 * nb3 + i2 * nb2 + i1 * nb1; + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(dst_addr, dst_spad), rctx->dst_row_size, rctx->dst_row_size_aligned, cnr); + + // Prefetch more rows (if any) + if ((cr + HTP_ROPE_SPAD_NROWS) < nrows) { + uint32_t pnr = MIN(nrows - (cr + HTP_ROPE_SPAD_NROWS), HTP_ROPE_SPAD_BLOCK); + uint32_t pi1 = i1 + HTP_ROPE_SPAD_NROWS; + uint32_t pir = ir + HTP_ROPE_SPAD_NROWS; + + const uint8_t * src_addr = (const uint8_t *) src0->data + i3 * nb03 + i2 * nb02 + pi1 * nb01; + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src_spad, src_addr), + rctx->src0_row_size_aligned, rctx->src0_row_size, pnr); + + // FARF(HIGH, "rope-prefetch %u: pr %u i1 %u i2 %u i3 %u src-spad %p src-addr %p pnr %u", ith, pir, pi1, i2, i3, src_spad, src_addr, pnr); + } + } + } + } } - rope_hex_f32(rope_ctx, src0_start_row, src0_end_row, nth, ith, opt_path); +done: + dma_queue_flush(dma_queue); + tt = HAP_perf_get_qtimer_count() - tt; - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "rope-f32: %d/%d/%d: (%u:%u) usec %u\n", ith, nth, opt_path, src0_start_row, src0_end_row, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - -static void rope_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct rope_th_ctx * rope_ctx = (struct rope_th_ctx *) data; - - rope_job_f32_per_thread(rope_ctx, n, i); + FARF(HIGH, "rope-f32: %d/%d: (%u:%u) usec %u\n", ith, nth, src0_start_row, src0_end_row, (unsigned) HAP_perf_qtimer_count_to_us(tt)); } static int execute_op_rope_f32(struct htp_ops_context * octx) { @@ -394,17 +389,10 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const struct htp_tensor * src2 = &octx->src2; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct rope_th_ctx rope_ctx; + const char * op_type = "rope-f32"; switch (octx->op) { case HTP_OP_ROPE: - op_func = rope_job_dispatcher_f32; - op_type = "rope-f32"; - - init_rope_ctx(&rope_ctx, octx); break; default: @@ -415,49 +403,79 @@ static int execute_op_rope_f32(struct htp_ops_context * octx) { const uint32_t n_threads = octx->n_threads; const size_t src0_row_size = src0->nb[1]; - const size_t src1_row_size = src0_row_size; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - // N rows per thread, padded to HVX vector size - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; - octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Aligned row sizes for VTCM + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); + const size_t theta_cache_size_aligned = hex_round_up(src0->ne[0] * sizeof(float), 128); - size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; + // Calculate spad sizes per thread + size_t src0_spad_per_thread = theta_cache_size_aligned + HTP_ROPE_SPAD_NROWS * src0_row_size_aligned; + size_t dst_spad_per_thread = HTP_ROPE_SPAD_NROWS * dst_row_size_aligned; + size_t spad_per_thread = src0_spad_per_thread + dst_spad_per_thread; - if (src2->ne[0]) { - FARF(HIGH, - "%s: %ux%ux%ux%u (x %ux%ux%ux%u x %ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u " - "dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], - dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); - } else { - FARF(HIGH, - "%s: %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", - op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], - src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, - octx->dst_spad.size); - } - - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); + // Check if we fit in VTCM + size_t total_vtcm_needed = spad_per_thread * n_threads; + if (octx->ctx->vtcm_size < total_vtcm_needed) { + FARF(ERROR, "%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, total_vtcm_needed); return HTP_STATUS_VTCM_TOO_SMALL; } - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; - octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; + // Assign sizes + octx->src0_spad.size_per_thread = src0_spad_per_thread; + octx->dst_spad.size_per_thread = dst_spad_per_thread; + octx->src0_spad.size = n_threads * src0_spad_per_thread; + octx->dst_spad.size = n_threads * dst_spad_per_thread; + octx->src1_spad.size = 0; + // Assign pointers + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = NULL; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; + + // Fill context + struct htp_rope_context rctx; + memset(&rctx, 0, sizeof(struct htp_rope_context)); + + rctx.t_start = HAP_perf_get_qtimer_count(); + + rctx.octx = octx; + + const int32_t * op_params = &octx->op_params[0]; + rctx.n_dims = ((const int32_t *) op_params)[1]; + rctx.mode = ((const int32_t *) op_params)[2]; + rctx.n_ctx_orig = ((const int32_t *) op_params)[4]; + + memcpy(&rctx.freq_base, (int32_t *) op_params + 5, sizeof(float)); + memcpy(&rctx.freq_scale, (int32_t *) op_params + 6, sizeof(float)); + memcpy(&rctx.ext_factor, (int32_t *) op_params + 7, sizeof(float)); + memcpy(&rctx.attn_factor, (int32_t *) op_params + 8, sizeof(float)); + memcpy(&rctx.beta_fast, (int32_t *) op_params + 9, sizeof(float)); + memcpy(&rctx.beta_slow, (int32_t *) op_params + 10, sizeof(float)); + memcpy(&rctx.sections, (int32_t *) op_params + 11, sizeof(int) * 4); + + rctx.theta_scale = powf(rctx.freq_base, -2.0f / rctx.n_dims); + + rope_corr_dims(rctx.n_dims, rctx.n_ctx_orig, rctx.freq_base, rctx.beta_fast, rctx.beta_slow, rctx.corr_dims); + + rctx.src0_row_size = src0_row_size; + rctx.dst_row_size = dst_row_size; + rctx.src0_row_size_aligned = src0_row_size_aligned; + rctx.dst_row_size_aligned = dst_row_size_aligned; + rctx.theta_cache_offset = theta_cache_size_aligned; + + uint32_t ne0 = dst->ne[0]; uint32_t src0_nrows = src0->ne[1] * src0->ne[2] * src0->ne[3]; + rctx.src0_nrows = src0_nrows; + + FARF(HIGH, "rope-f32 n-rows %u n-dims %d ne0 %u ext-factor %.6f theta-scale %.6f attn-factor %.6f\n", rctx.src0_nrows, rctx.n_dims, ne0, + rctx.ext_factor, rctx.theta_scale, rctx.attn_factor); if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &rope_ctx, n_jobs); + uint32_t n_jobs = MIN(n_threads, src0_nrows); + rctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, rope_job_f32, &rctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c index 904484da9d..2fd6c90772 100644 --- a/ggml/src/ggml-hexagon/htp/set-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -43,11 +43,21 @@ \ const uint32_t nr = ne01; -static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { +struct htp_set_rows_context { + struct htp_ops_context * octx; + struct fastdiv_values div_ne12; + struct fastdiv_values div_ne11; + uint32_t src0_nrows_per_thread; +}; + +static void set_rows_thread_f32_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -56,8 +66,8 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -76,15 +86,16 @@ static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; } -static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { +static void set_rows_thread_f16_f32(unsigned int nth, unsigned int ith, void *data) { + struct htp_set_rows_context * srctx = (struct htp_set_rows_context *)data; + struct htp_ops_context * octx = srctx->octx; + set_rows_preamble; // parallelize by rows of src0 - const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t dr = srctx->src0_nrows_per_thread; const uint32_t ir0 = dr * ith; const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; @@ -93,8 +104,8 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, for (uint32_t i03 = 0; i03 < ne03; ++i03) { for (uint32_t i02 = 0; i02 < ne02; ++i02) { for (uint32_t i = ir0; i < ir1; ++i) { - const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); - const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i12 = fastmodulo(i03, ne12, &srctx->div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &srctx->div_ne11); const uint32_t i10 = i; const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; @@ -112,16 +123,6 @@ static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, } } } - - return HTP_STATUS_OK; -} - -static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); -} - -static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { - set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); } int op_set_rows(struct htp_ops_context * octx) { @@ -143,18 +144,20 @@ int op_set_rows(struct htp_ops_context * octx) { return HTP_STATUS_OK; } - octx->set_rows_div_ne12 = init_fastdiv_values(ne12); - octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + struct htp_set_rows_context srctx; + srctx.octx = octx; + srctx.div_ne12 = init_fastdiv_values(ne12); + srctx.div_ne11 = init_fastdiv_values(ne11); const uint32_t n_jobs = MIN(nr, octx->n_threads); - octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + srctx.src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; switch(octx->dst.type) { case HTP_TYPE_F32: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f32_f32, &srctx, n_jobs); break; case HTP_TYPE_F16: - worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + worker_pool_run_func(octx->ctx->worker_pool, set_rows_thread_f16_f32, &srctx, n_jobs); break; default: return HTP_STATUS_NO_SUPPORT; diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index e91a16d947..6e22eb6a63 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -10,6 +10,7 @@ #include "hex-dma.h" #include "hvx-utils.h" +#include "hex-fastdiv.h" #define GGML_COMMON_DECL_C #include "ggml-common.h" @@ -48,7 +49,7 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -struct softmax_th_ctx { +struct htp_softmax_context { bool use_f16; bool use_src1; uint32_t n_head; @@ -59,28 +60,48 @@ struct softmax_th_ctx { float m0; float m1; + uint32_t src0_nrows_per_thread; + struct fastdiv_values fastdiv_ne01; + struct fastdiv_values fastdiv_ne02; + struct fastdiv_values fastdiv_ne12; // For mask broadcasting + struct fastdiv_values fastdiv_ne13; // For mask broadcasting + size_t spad_stride; + struct htp_ops_context * octx; }; -static void init_softmax_ctx(struct softmax_th_ctx * softmax_ctx, struct htp_ops_context * octx) { +static void init_softmax_ctx(struct htp_softmax_context * smctx, struct htp_ops_context * octx) { const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; - memset(softmax_ctx, 0, sizeof(struct softmax_th_ctx)); + memset(smctx, 0, sizeof(struct htp_softmax_context)); - memcpy(&softmax_ctx->scale, (float *) octx->op_params, sizeof(float)); - memcpy(&softmax_ctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&smctx->scale, (float *) octx->op_params, sizeof(float)); + memcpy(&smctx->max_bias, (float *) octx->op_params + 1, sizeof(float)); - softmax_ctx->n_head = src0->ne[2]; - softmax_ctx->n_head_log2 = 1u << (uint32_t) floor(log2(softmax_ctx->n_head)); + smctx->n_head = src0->ne[2]; + smctx->n_head_log2 = 1u << (uint32_t) floor(log2(smctx->n_head)); - softmax_ctx->m0 = powf(2.0f, -(softmax_ctx->max_bias) / softmax_ctx->n_head_log2); - softmax_ctx->m1 = powf(2.0f, -(softmax_ctx->max_bias / 2.0f) / softmax_ctx->n_head_log2); + smctx->m0 = powf(2.0f, -(smctx->max_bias) / smctx->n_head_log2); + smctx->m1 = powf(2.0f, -(smctx->max_bias / 2.0f) / smctx->n_head_log2); - softmax_ctx->use_src1 = (src1->ne[0] != 0); - softmax_ctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); + smctx->use_src1 = (src1->ne[0] != 0); + smctx->use_f16 = (src1->ne[0] != 0) && (src1->type == HTP_TYPE_F16); - softmax_ctx->octx = octx; + smctx->octx = octx; + + // Initialize fastdiv values + const uint32_t ne01 = src0->ne[1]; + const uint32_t ne02 = src0->ne[2]; + + if (ne01 > 0) smctx->fastdiv_ne01 = init_fastdiv_values(ne01); + if (ne02 > 0) smctx->fastdiv_ne02 = init_fastdiv_values(ne02); + + const uint32_t ne12 = (src1->ne[0]) ? src1->ne[2] : 1; + const uint32_t ne13 = (src1->ne[0]) ? src1->ne[3] : 1; + + if (ne12 > 0) smctx->fastdiv_ne12 = init_fastdiv_values(ne12); + if (ne13 > 0) smctx->fastdiv_ne13 = init_fastdiv_values(ne13); } static void hvx_fast_softmax_prep_f32(const uint8_t * restrict src, @@ -139,8 +160,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, max_vec = Q6_Vsf_vmax_VsfVsf(max_vec, v1); } - HVX_Vector v = hvx_vec_reduce_max_f32(max_vec); - max_vec = hvx_vec_repl4(v); + max_vec = hvx_vec_reduce_max_f32(max_vec); // replicated over all lanes #pragma unroll(4) for (int i = 0; i < step_of_1; i++) { @@ -154,8 +174,7 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src, v_pad[i] = v3; } - v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); - sum_vec = hvx_vec_repl4(v); + sum_vec = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec)); // replicated over all lanes HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); @@ -183,83 +202,9 @@ static float hvx_softmax_f32(const uint8_t * restrict src, return sum; } -static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ctx, int opt_path) { - struct htp_ops_context * octx = softmax_ctx->octx; - - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * dst = &octx->dst; - - htp_softmax_preamble3; - - uint8_t * src0_spad_data = octx->src0_spad.data + (ith * nb01); - uint8_t * src1_spad_data = octx->src1_spad.data + (ith * nb01); - uint8_t * dst_spad_data = octx->dst_spad.data + (ith * nb1); - - float * wp0 = (float *) src0_spad_data; - float * wp1 = (float *) src1_spad_data; - float * wp2 = (float *) dst_spad_data; - - for (uint32_t i03 = 0; i03 < ne03; i03++) { - for (uint32_t i02 = 0; i02 < ne02; i02++) { - for (uint32_t i01 = ith; i01 < ne01; i01 += nth) { - const uint32_t i11 = i01; - const uint32_t i12 = i02 % ne12; - const uint32_t i13 = i03 % ne13; - - // ALiBi - const uint32_t h = i02; // head - - const float slope = (softmax_ctx->max_bias > 0.0f) ? - h < softmax_ctx->n_head_log2 ? - powf(softmax_ctx->m0, h + 1) : - powf(softmax_ctx->m1, 2 * (h - softmax_ctx->n_head_log2) + 1) : - 1.0f; - - float * sp = (float *) ((char *) octx->src0.data + i01 * nb01 + i02 * nb02 + i03 * nb03); - float * dp = (float *) ((char *) octx->dst.data + i01 * nb1 + i02 * nb2 + i03 * nb3); - - // broadcast the mask across rows - __fp16 * mp_f16 = (softmax_ctx->use_src1) ? - (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - float * mp_f32 = (softmax_ctx->use_src1) ? - (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : - NULL; - - if ((1 == opt_path) && (mp_f32) && !(softmax_ctx->use_f16)) { - hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, - (const uint8_t *) mp_f32, slope); - } else { - hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); - if (mp_f32) { - if (softmax_ctx->use_f16) { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * (float) mp_f16[i]; - } - } else { - for (int i = 0; i < ne00; ++i) { - wp0[i] += slope * mp_f32[i]; - } - } - } - } - - if (1 == opt_path) { - hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); - } else { - float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); - float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); - sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); - } - } - } - } -} - -static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int nth, int ith) { - struct htp_ops_context * octx = softmax_ctx->octx; +static void softmax_job_f32(unsigned int nth, unsigned int ith, void * data) { + struct htp_softmax_context * smctx = (struct htp_softmax_context *) data; + struct htp_ops_context * octx = smctx->octx; const struct htp_tensor * src0 = &octx->src0; const struct htp_tensor * src1 = &octx->src1; @@ -268,7 +213,7 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int htp_softmax_preamble3; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + const uint32_t src0_nrows_per_thread = smctx->src0_nrows_per_thread; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -291,20 +236,103 @@ static void softmax_job_f32_per_thread(struct softmax_th_ctx * softmax_ctx, int opt_path = 1; } - softmax_htp_f32(nth, ith, softmax_ctx, opt_path); + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * smctx->spad_stride); + uint8_t * src1_spad_data = octx->src1_spad.data + (ith * smctx->spad_stride); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * smctx->spad_stride); + + float * wp0 = (float *) src0_spad_data; + float * wp1 = (float *) src1_spad_data; + float * wp2 = (float *) dst_spad_data; + + uint32_t prev_i2 = (uint32_t)-1; + float slope = 1.0f; + + for (uint32_t r = src0_start_row; r < src0_end_row; ++r) { + uint32_t i1 = fastmodulo(r, ne01, &smctx->fastdiv_ne01); + uint32_t r_div_ne01 = fastdiv(r, &smctx->fastdiv_ne01); + uint32_t i2 = fastmodulo(r_div_ne01, ne02, &smctx->fastdiv_ne02); + uint32_t i3 = fastdiv(r_div_ne01, &smctx->fastdiv_ne02); + + // Map to original logic indices + // i01 = i1 + // i02 = i2 + // i03 = i3 + + const uint32_t i11 = i1; + // const uint32_t i12 = i2 % ne12; + // const uint32_t i13 = i3 % ne13; + + uint32_t i12, i13; + if (ne12 == ne02) { + i12 = i2; + } else { + i12 = fastmodulo(i2, ne12, &smctx->fastdiv_ne12); + } + + if (ne13 == ne03) { + i13 = i3; + } else { + i13 = fastmodulo(i3, ne13, &smctx->fastdiv_ne13); + } + + // ALiBi + if (i2 != prev_i2) { + const uint32_t h = i2; // head + + slope = (smctx->max_bias > 0.0f) ? + h < smctx->n_head_log2 ? + powf(smctx->m0, h + 1) : + powf(smctx->m1, 2 * (h - smctx->n_head_log2) + 1) : + 1.0f; + prev_i2 = i2; + } + + float * sp = (float *) ((char *) octx->src0.data + i1 * nb01 + i2 * nb02 + i3 * nb03); + float * dp = (float *) ((char *) octx->dst.data + i1 * nb1 + i2 * nb2 + i3 * nb3); + + // broadcast the mask across rows + __fp16 * mp_f16 = (smctx->use_src1) ? + (__fp16 *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + float * mp_f32 = (smctx->use_src1) ? + (float *) ((char *) octx->src1.data + i11 * nb11 + i12 * nb12 + i13 * nb13) : + NULL; + + if ((1 == opt_path) && (mp_f32) && !(smctx->use_f16)) { + hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, smctx->scale, + (const uint8_t *) mp_f32, slope); + } else { + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, smctx->scale); + if (mp_f32) { + if (smctx->use_f16) { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * (float) mp_f16[i]; + } + } else { + for (int i = 0; i < ne00; ++i) { + wp0[i] += slope * mp_f32[i]; + } + } + } + } + + if (1 == opt_path) { + hvx_fast_softmax_f32((const uint8_t *) wp0, (uint8_t *) dp, (uint8_t *) wp1, ne00); + } else { + float max = hvx_reduce_max_f32((const uint8_t *) wp0, ne00); + float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); + sum = sum > 0.0 ? (1.0 / sum) : 1; + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); + } + } t2 = HAP_perf_get_qtimer_count(); FARF(HIGH, "softmax-f32 %d/%d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - softmax_ctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, + smctx->use_f16, opt_path, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void softmax_job_dispatcher_f32(unsigned int n, unsigned int i, void * p_data) { - struct softmax_th_ctx * p_softmax_ctx = (struct softmax_th_ctx *) p_data; - softmax_job_f32_per_thread(p_softmax_ctx, n, i); -} - static int execute_op_softmax_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; @@ -312,17 +340,12 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { const struct htp_tensor * src1 = &octx->src1; struct htp_tensor * dst = &octx->dst; - worker_callback_t op_func; - const char * op_type = NULL; - - struct softmax_th_ctx softmax_ctx; + struct htp_softmax_context smctx; + const char * op_type = "softmax-f32"; switch (octx->op) { case HTP_OP_SOFTMAX: - op_func = softmax_job_dispatcher_f32; - op_type = "softmax-f32"; - - init_softmax_ctx(&softmax_ctx, octx); + init_softmax_ctx(&smctx, octx); break; default: @@ -342,6 +365,9 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; octx->src1_spad.size = hex_round_up(src1_row_size, 128) * n_threads; + // Use stride for calculating offset + smctx.spad_stride = hex_round_up(src0_row_size, 128); + size_t spad_size = octx->src0_spad.size + octx->src1_spad.size + octx->dst_spad.size; if (src1->ne[0]) { @@ -371,8 +397,8 @@ static int execute_op_softmax_f32(struct htp_ops_context * octx) { if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, op_func, &softmax_ctx, n_jobs); + smctx.src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + worker_pool_run_func(octx->ctx->worker_pool, softmax_job_f32, &smctx, n_jobs); } return err; diff --git a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c index 62e45da2b3..04fa72182a 100644 --- a/ggml/src/ggml-hexagon/htp/sum-rows-ops.c +++ b/ggml/src/ggml-hexagon/htp/sum-rows-ops.c @@ -17,7 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" - #define sum_rows_preamble \ struct htp_tensor *src0 = &octx->src0;\ struct htp_tensor *dst = &octx->dst; \ @@ -42,53 +41,54 @@ const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; \ -static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) { - sum_rows_preamble; +struct sum_rows_context { + const uint8_t * src_data; + uint8_t * dst_data; + uint32_t ne00; + size_t src_stride; + size_t dst_stride; + uint32_t rows_per_thread; + uint32_t total_rows; + bool opt_path; +}; - const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; +static void sum_rows_thread_f32(unsigned int nth, unsigned int ith, void *data) { + const struct sum_rows_context * smctx = (const struct sum_rows_context *) data; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const uint32_t rows_per_thread = smctx->rows_per_thread; + const uint32_t total_rows = smctx->total_rows; - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + const uint32_t start_row = rows_per_thread * ith; + const uint32_t end_row = MIN(start_row + rows_per_thread, total_rows); - // no work for this thread - if (src0_start_row >= src0_end_row) { - return HTP_STATUS_OK; + if (start_row >= end_row) { + return; } - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { - opt_path = 1; - } + const size_t src_stride = smctx->src_stride; + const size_t dst_stride = smctx->dst_stride; + const uint32_t ne00 = smctx->ne00; + const bool opt_path = smctx->opt_path; - const uint8_t * restrict data_src = (const uint8_t *) src0->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + const float * restrict src_th = (const float *) (smctx->src_data + (start_row * src_stride)); + float * restrict dst_th = (float *) (smctx->dst_data + (start_row * dst_stride)); - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); + // Calculate actual number of rows for this thread + const uint32_t n_rows = end_row - start_row; - for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) { - const float * restrict src_local = src_th + (ir * ne00); + for (uint32_t ir = 0; ir < n_rows; ir++) { + const float * restrict src_local = src_th + (ir * (src_stride / sizeof(float))); - if (ir + 1 < src0_nrows_per_thread) { - hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1); + if (ir + 1 < n_rows) { + hex_l2fetch(src_local + (src_stride / sizeof(float)), src_stride, src_stride, 1); } - if (1 == opt_path) { + if (opt_path) { dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00); } else { dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00); } } - - return HTP_STATUS_OK; -} - -static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) { - sum_rows_thread_f32((struct htp_ops_context *) data, n, i); } int op_sum_rows(struct htp_ops_context * octx) { @@ -106,10 +106,25 @@ int op_sum_rows(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01 * ne02 * ne03; uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + uint32_t rows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; - worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs); + bool opt_path = false; + if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) { + opt_path = true; + } + + struct sum_rows_context smctx = { + .src_data = (const uint8_t *) src0->data, + .dst_data = (uint8_t *) dst->data, + .ne00 = ne00, + .src_stride = nb01, + .dst_stride = nb1, + .rows_per_thread = rows_per_thread, + .total_rows = src0_nrows, + .opt_path = opt_path, + }; + + worker_pool_run_func(octx->ctx->worker_pool, sum_rows_thread_f32, &smctx, n_jobs); return HTP_STATUS_OK; } - diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index ce879bf037..98135c50ab 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -17,6 +17,28 @@ #include "htp-msg.h" #include "htp-ops.h" +struct htp_unary_context { + struct htp_ops_context * octx; + + // Precomputed values + const uint8_t * data_src0; + uint8_t * data_dst; + + size_t src0_row_size; + size_t dst_row_size; + + size_t src0_row_size_aligned; + size_t dst_row_size_aligned; + + size_t src0_spad_half_size; + size_t dst_spad_half_size; + + uint32_t block; + uint32_t src0_nrows; + uint32_t src0_nrows_per_thread; + uint32_t nc; +}; + #define htp_unary_preamble \ const uint32_t ne00 = src->ne[0]; \ const uint32_t ne01 = src->ne[1]; \ @@ -57,8 +79,7 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); } - HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); - sum_v = hvx_vec_repl4(reduced_sum); + sum_v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v)); // replicated over all lanes HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); @@ -75,128 +96,95 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } -static void scale_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void scale_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float scale = 0.f; float bias = 0.f; memcpy(&scale, &op_params[0], sizeof(float)); memcpy(&bias, &op_params[1], sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + hvx_scale_offset_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); } } -static void rms_norm_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void rms_norm_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { float epsilon = 0.f; memcpy(&epsilon, op_params, sizeof(float)); for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); - } else { - float sum = hvx_sum_of_squares_f32((const uint8_t *) src_local, row_elems); - - const float mean = sum / row_elems; - const float scale = 1.0f / sqrtf(mean + epsilon); - - hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); - } + hvx_fast_rms_norm_f32((const uint8_t *) src_local, (uint8_t *) dst_local, spad, row_elems, epsilon); } } -static void sqr_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqr_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void sqrt_htp_f32(const float * restrict src, - float * restrict dst, - uint8_t * restrict spad, - const uint32_t num_rows, - const uint32_t row_elems, - const size_t row_size, - int32_t * op_params, - int opt_path) { +static void sqrt_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params) { for (uint32_t ir = 0; ir < num_rows; ir++) { - const float * restrict src_local = src + (ir * row_elems); - float * restrict dst_local = dst + (ir * row_elems); + const uint8_t * restrict src_local = (const uint8_t *)src + (ir * row_size); + uint8_t * restrict dst_local = (uint8_t *)dst + (ir * row_size); - if (ir + 1 < num_rows) { - hex_l2fetch(src_local + row_elems, row_size, row_size, 1); - } - - if (1 == opt_path) { - hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } else { - hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); - } + hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems); } } -static void unary_job_f32_per_thread(const struct htp_tensor * src, - struct htp_tensor * dst, - uint8_t * spad, - int htp_op, - int32_t * op_params, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread) { +static void unary_job_f32_per_thread(unsigned int nth, unsigned int ith, void * data) { + const struct htp_unary_context * uctx = (const struct htp_unary_context *) data; + struct htp_ops_context * octx = uctx->octx; + const struct htp_tensor * src = &octx->src0; + const struct htp_tensor * dst = &octx->dst; + htp_unary_preamble; - const size_t src0_row_size = nb01; - const size_t dst_row_size = nb1; + int htp_op = octx->op; + int32_t * op_params = octx->op_params; + uint32_t src0_nrows_per_thread = uctx->src0_nrows_per_thread; - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + const size_t src0_row_size = uctx->src0_row_size; + const size_t dst_row_size = uctx->dst_row_size; + const size_t src0_row_size_aligned = uctx->src0_row_size_aligned; + const size_t dst_row_size_aligned = uctx->dst_row_size_aligned; + + const uint32_t src0_nrows = uctx->src0_nrows; const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); @@ -208,79 +196,104 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - int opt_path = 0; - if ((0 == hex_is_aligned((void *) src->data, VLEN)) || (0 == hex_is_aligned((void *) dst->data, VLEN))) { - is_aligned = 0; - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const uint8_t * restrict data_src = uctx->data_src0; + uint8_t * restrict data_dst = uctx->data_dst; + + uint8_t * src0_spad_data = octx->src0_spad.data + (ith * octx->src0_spad.size_per_thread); + uint8_t * dst_spad_data = octx->dst_spad.data + (ith * octx->dst_spad.size_per_thread); + + size_t src0_spad_half_size = uctx->src0_spad_half_size; + size_t dst_spad_half_size = uctx->dst_spad_half_size; + + const int BLOCK = uctx->block; + if (BLOCK == 0) { + FARF(ERROR, "unary-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + octx->src0_spad.size_per_thread, src0_row_size_aligned); + return; } - const uint8_t * restrict data_src = (const uint8_t *) src->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + dma_queue * dma_queue = octx->ctx->dma[ith]; - const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size)); - float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size)); - uint8_t * restrict spad_th = (uint8_t *) spad + (ith * nb01); + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - switch (htp_op) { - case HTP_OP_RMS_NORM: - rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SCALE: - scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQR: - sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; - case HTP_OP_SQRT: - sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); - break; + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); - default: - break; + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); } + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + // Process block in VTCM + switch (htp_op) { + case HTP_OP_RMS_NORM: + rms_norm_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SCALE: + scale_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQR: + sqr_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + case HTP_OP_SQRT: + sqrt_f32(src0_spad, dst_spad, NULL, block_size, ne0, src0_row_size_aligned, op_params); + break; + default: + break; + } + + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "unary-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, src->ne[0], + FARF(HIGH, "unary-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, src->ne[0], src->ne[1], src->ne[2], src->ne[3], src0_start_row, src0_end_row, dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void unary_job_dispatcher_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = (struct htp_ops_context *) data; - - unary_job_f32_per_thread(&octx->src0, &octx->dst, octx->src0_spad.data, octx->op, octx->op_params, n, i, - octx->src0_nrows_per_thread); -} - static int execute_op_unary_f32(struct htp_ops_context * octx) { int err = HTP_STATUS_OK; const struct htp_tensor * src0 = &octx->src0; struct htp_tensor * dst = &octx->dst; - worker_callback_t unary_op_func; - const char * op_type = NULL; + const char * op_type = NULL; switch (octx->op) { case HTP_OP_RMS_NORM: - unary_op_func = unary_job_dispatcher_f32; - op_type = "rmsnorm-f32"; + op_type = "rmsnorm-f32"; break; case HTP_OP_SCALE: - unary_op_func = unary_job_dispatcher_f32; - op_type = "scale-f32"; + op_type = "scale-f32"; break; case HTP_OP_SQR: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqr-f32"; + op_type = "sqr-f32"; break; case HTP_OP_SQRT: - unary_op_func = unary_job_dispatcher_f32; - op_type = "sqrt-f32"; + op_type = "sqrt-f32"; break; default: @@ -294,32 +307,61 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { const size_t src0_row_size = src0->nb[1]; const size_t dst_row_size = dst->nb[1]; - // VTCM scratchpads for all tensors - octx->dst_spad.size = hex_round_up(dst_row_size, 128) * n_threads; - octx->src0_spad.size = hex_round_up(src0_row_size, 128) * n_threads; + const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN); - size_t spad_size = octx->src0_spad.size + octx->dst_spad.size; + // VTCM scratchpads for all tensors + // N rows per thread, padded to HVX vector size + // Double buffering requires 2x size per buffer + + size_t spad_size_per_row = 2 * (src0_row_size_aligned + dst_row_size_aligned); + size_t vtcm_row_per_thread = (octx->ctx->vtcm_size)/ (n_threads * spad_size_per_row); + + // Make sure the reserved vtcm size is sufficient + if (vtcm_row_per_thread == 0) { + FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, + spad_size_per_row * n_threads); + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.size_per_thread = src0_row_size_aligned * vtcm_row_per_thread * 2; + octx->dst_spad.size_per_thread = dst_row_size_aligned * vtcm_row_per_thread * 2; + + octx->src0_spad.size = n_threads * octx->src0_spad.size_per_thread; + octx->dst_spad.size = n_threads * octx->dst_spad.size_per_thread; + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; FARF(HIGH, "%s: (%ux%ux%ux%u) -> (%ux%ux%ux%u) : src0-spad-size %u src1-spad-size %u dst-spad-size %u\n", op_type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size); - // Make sure the reserved vtcm size is sufficient - if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "unary-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, octx->ctx->vtcm_size, - spad_size); - return HTP_STATUS_VTCM_TOO_SMALL; - } - - octx->src0_spad.data = octx->ctx->vtcm_base; - octx->dst_spad.data = octx->src0_spad.data + octx->src0_spad.size; - if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { uint32_t n_jobs = MIN(n_threads, src0_nrows); - octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs; + struct htp_unary_context uctx = { + .octx = octx, + .src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs, + .src0_nrows = src0_nrows, - worker_pool_run_func(octx->ctx->worker_pool, unary_op_func, octx, n_jobs); + .data_src0 = (const uint8_t *)src0->data, + .data_dst = (uint8_t *)dst->data, + + .src0_row_size = src0_row_size, + .dst_row_size = dst_row_size, + + .src0_row_size_aligned = src0_row_size_aligned, + .dst_row_size_aligned = dst_row_size_aligned, + + .src0_spad_half_size = octx->src0_spad.size_per_thread / 2, + .dst_spad_half_size = octx->dst_spad.size_per_thread / 2, + + .block = (octx->src0_spad.size_per_thread / 2) / src0_row_size_aligned, + .nc = src0->ne[0], + }; + + worker_pool_run_func(octx->ctx->worker_pool, unary_job_f32_per_thread, &uctx, n_jobs); } return err; diff --git a/ggml/src/ggml-sycl/add-id.cpp b/ggml/src/ggml-sycl/add-id.cpp index 00c073cf93..8929017a99 100644 --- a/ggml/src/ggml-sycl/add-id.cpp +++ b/ggml/src/ggml-sycl/add-id.cpp @@ -55,7 +55,11 @@ void ggml_sycl_add_id(ggml_backend_sycl_context& ctx, ggml_tensor* dst) { const int32_t* src2_d = (const int32_t*)src2->data; float* dst_d = (float*)dst->data; - int threads = std::min((int)ne00, 768); // cols + const unsigned int max_work_group_size = ggml_sycl_info().max_work_group_sizes[ctx.device]; + assert(work_group_size % (WARP_SIZE * WARP_SIZE) == 0); + + int threads = std::min((unsigned int)ne00, max_work_group_size); // cols + ctx.stream()->parallel_for( sycl::nd_range<3>( sycl::range<3>(1, ne02, ne01) * sycl::range<3>(1, 1, threads), diff --git a/ggml/src/ggml-sycl/binbcast.cpp b/ggml/src/ggml-sycl/binbcast.cpp index 0a3883ae1e..92dd18889f 100644 --- a/ggml/src/ggml-sycl/binbcast.cpp +++ b/ggml/src/ggml-sycl/binbcast.cpp @@ -11,8 +11,8 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i0s = item_ct1.get_local_range(2) * item_ct1.get_group(2) + item_ct1.get_local_id(2); @@ -44,7 +44,7 @@ static void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst, for (int i0 = i0s; i0 < ne0; i0 += item_ct1.get_local_range(2) * item_ct1.get_group_range(2)) { const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } } @@ -53,8 +53,8 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t int ne0, int ne1, int ne2, int ne3, int ne10, int ne11, int ne12, int ne13, /*int s0, */ int s1, int s2, int s3, - /*int s00,*/ int s01, int s02, int s03, - /*int s10,*/ int s11, int s12, int s13, + int s00, int s01, int s02, int s03, + int s10, int s11, int s12, int s13, const sycl::nd_item<3> &item_ct1) { const int i = item_ct1.get_local_range(2) * item_ct1.get_group(2) + @@ -82,7 +82,7 @@ static void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t dst_t * dst_row = dst + i_dst; const int i10 = i0 % ne10; - dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0] : 0.0f, (float)src1_row[i10]); + dst_row[i0] = (dst_t)bin_op(src0 ? (float)src0_row[i0*s00] : 0.0f, (float)src1_row[i10*s10]); } @@ -95,7 +95,8 @@ struct bin_bcast_sycl { const int64_t ne3, const size_t nb00, const size_t nb01, const size_t nb02, const size_t nb03, const size_t nb10, const size_t nb11, const size_t nb12, const size_t nb13, const size_t nb0, const size_t nb1, const size_t nb2, const size_t nb3, const bool src0_is_contiguous, - const bool src1_is_contiguous, const bool dst_is_contiguous, queue_ptr stream) { + const bool src1_is_contiguous, const bool src0_is_permuted, const bool src1_is_permuted, + queue_ptr stream) { int nr0 = ne10 / ne0; int nr1 = ne11/ne1; int nr2 = ne12/ne2; @@ -123,7 +124,7 @@ struct bin_bcast_sycl { cnb[3] *= cne[3]; }; - if (src0_is_contiguous && src1_is_contiguous && dst_is_contiguous) { + if (src0_is_contiguous && src1_is_contiguous && !src0_is_permuted && !src1_is_permuted) { for (int i = 0; i < 4; i++) { if (nr[i] != 1) { break; @@ -164,7 +165,7 @@ struct bin_bcast_sycl { size_t nb12 = cnb1[2]; size_t nb13 = cnb1[3]; - size_t s0 = nb0 / sizeof(dst_t); + // size_t s0 = nb0 / sizeof(dst_t); size_t s1 = nb1 / sizeof(dst_t); size_t s2 = nb2 / sizeof(dst_t); size_t s3 = nb3 / sizeof(dst_t); @@ -196,9 +197,6 @@ struct bin_bcast_sycl { GGML_ASSERT(nb12 % sizeof(src1_t) == 0); GGML_ASSERT(nb13 % sizeof(src1_t) == 0); - GGML_ASSERT(s0 == 1); - GGML_ASSERT(s10 == 1); - const int block_size = 128; int64_t hne0 = std::max(ne0/2LL, 1LL); @@ -232,8 +230,8 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast_unravel( src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, - ne10, ne11, ne12, ne13, s1, s2, s3, s01, s02, - s03, s11, s12, s13, item_ct1); + ne10, ne11, ne12, ne13, s1, s2, s3, s00, s01, s02, + s03, s10, s11, s12, s13, item_ct1); }); } } else { @@ -251,7 +249,7 @@ struct bin_bcast_sycl { [=](sycl::nd_item<3> item_ct1) { k_bin_bcast(src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3, ne10, ne11, ne12, ne13, - s1, s2, s3, s01, s02, s03, s11, s12, s13, + s1, s2, s3, s00, s01, s02, s03, s10, s11, s12, s13, item_ct1); }); } @@ -268,24 +266,27 @@ inline void ggml_sycl_op_bin_bcast(ggml_backend_sycl_context & ctx, const ggml_t if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { op()((const float *) src0->data, (const float *) src1->data, (float *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, nb3, - ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const sycl::half *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, - nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), + nb0, nb1, nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), main_stream); } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { op()((const sycl::half *) src0->data, (const float *) src1->data, (sycl::half *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, - nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb2, nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { op()((const int32_t *) src0->data, (const int32_t *) src1->data, (int32_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else if (src0->type == GGML_TYPE_I16 && src1->type == GGML_TYPE_I16 && dst->type == GGML_TYPE_I16) { op()((const int16_t *) src0->data, (const int16_t *) src1->data, (int16_t *) dst->data, ne00, ne01, ne02, ne03, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, nb00, nb01, nb02, nb03, nb10, nb11, nb12, nb13, nb0, nb1, nb2, - nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_contiguous(dst), main_stream); + nb3, ggml_is_contiguous(src0), ggml_is_contiguous(src1), ggml_is_permuted(src0), ggml_is_permuted(src1), + main_stream); } else { fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s, src1: %s\n", __func__, ggml_type_name(dst->type), ggml_type_name(src0->type), ggml_type_name(src1->type)); diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp index cc879e51d0..03a037f1cb 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-backend.cpp @@ -7,9 +7,21 @@ #include +static uint32_t validate_graph_operation(size_t cgraph_size, uint32_t shmem_res_id, const char * operation) { + if (cgraph_size == 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Zero-size computation graph\n", operation); + return 1; + } + + // place-holder: validate that the size of shmem_res_id is <= cgraph_size + // need to add another method in the Virgl->APIR callback interface + GGML_UNUSED(shmem_res_id); + + return 0; // Valid +} + uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { GGML_UNUSED(ctx); - GGML_UNUSED(enc); static bool async_backend_initialized = false; static bool async_backend; @@ -34,10 +46,26 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v size_t cgraph_size; apir_decode_size_t(dec, &cgraph_size); + if (validate_graph_operation(cgraph_size, shmem_res_id, __func__) != 0) { + apir_decoder_set_fatal(dec); + return 1; + } + apir_decoder secondary_dec = apir_new_decoder((const char *) shmem_data, cgraph_size); ggml_cgraph * cgraph = apir_decode_ggml_cgraph(&secondary_dec, cgraph_size); + if (!cgraph || apir_decoder_get_fatal(&secondary_dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to deserialize computation graph\n", __func__); + return 1; + } + + if (cgraph->n_nodes < 0 || cgraph->n_leafs < 0) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid negative node/leaf count: nodes=%d leafs=%d\n", __func__, + cgraph->n_nodes, cgraph->n_leafs); + return 1; + } + ggml_status status; #if APIR_BACKEND_CHECK_SUPPORTS_OP == 1 for (int idx = 0; idx < cgraph->n_nodes; idx++) { @@ -45,7 +73,8 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v if (dev->iface.supports_op(dev, op)) { continue; } - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", idx, ggml_op_desc(op)); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Graph node %d (%s) not supported by the backend\n", __func__, idx, + ggml_op_desc(op)); status = GGML_STATUS_ABORTED; apir_encode_ggml_status(enc, &status); @@ -53,9 +82,17 @@ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, v return 0; } #endif + + // Check if backend is properly initialized + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Backend not initialized (bck is null)\n", __func__); + + return 1; + } + status = bck->iface.graph_compute(bck, cgraph); - if (async_backend) { + if (async_backend && bck->iface.synchronize) { bck->iface.synchronize(bck); } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp index d55eec2761..c66dbaa9e8 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer-type.cpp @@ -85,7 +85,19 @@ uint32_t backend_buffer_type_get_alloc_size(apir_encoder * enc, apir_decoder * d const ggml_tensor * op = apir_decode_ggml_tensor_inplace(dec); - size_t value = buft->iface.get_alloc_size(buft, op); + // Check for decode error + if (op == nullptr) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Failed to decode tensor\n", __func__); + apir_decoder_set_fatal(dec); + return 1; + } + + size_t value; + if (buft->iface.get_alloc_size) { + value = buft->iface.get_alloc_size(buft, op); + } else { + value = ggml_nbytes(op); // Default fallback + } apir_encode_size_t(enc, &value); diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp index 8cc063ff0a..3ade8d99b4 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched-buffer.cpp @@ -6,11 +6,26 @@ #include +static uint32_t validate_buffer_operation(size_t offset, size_t size, const char * operation) { + // Only check for critical integer overflow - no arbitrary size limits + if (offset > SIZE_MAX - size) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Integer overflow in offset+size: %zu + %zu\n", operation, offset, size); + return 1; + } + + return 0; // Valid +} + uint32_t backend_buffer_get_base(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx) { GGML_UNUSED(ctx); ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + uintptr_t base = (uintptr_t) buffer->iface.get_base(buffer); apir_encode_uintptr_t(enc, &base); @@ -24,6 +39,11 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + ggml_tensor * tensor; // safe to remove the const qualifier here tensor = (ggml_tensor *) (uintptr_t) apir_decode_ggml_tensor(dec); @@ -37,6 +57,10 @@ uint32_t backend_buffer_set_tensor(apir_encoder * enc, apir_decoder * dec, virgl size_t size; apir_decode_size_t(dec, &size); + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { @@ -56,6 +80,11 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + const ggml_tensor * tensor; // safe to remove the const qualifier here tensor = apir_decode_ggml_tensor(dec); @@ -69,6 +98,10 @@ uint32_t backend_buffer_get_tensor(apir_encoder * enc, apir_decoder * dec, virgl size_t size; apir_decode_size_t(dec, &size); + if (validate_buffer_operation(offset, size, __func__) != 0) { + return 1; + } + void * shmem_data = ctx->iface->get_shmem_ptr(ctx->ctx_id, shmem_res_id); if (!shmem_data) { GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Couldn't get the shmem addr from virgl\n", __func__); @@ -86,6 +119,11 @@ uint32_t backend_buffer_cpy_tensor(apir_encoder * enc, apir_decoder * dec, virgl ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + const ggml_tensor * src; // safe to remove the const qualifier here src = apir_decode_ggml_tensor(dec); @@ -105,6 +143,11 @@ uint32_t backend_buffer_clear(apir_encoder * enc, apir_decoder * dec, virgl_apir ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + uint8_t value; apir_decode_uint8_t(dec, &value); @@ -120,6 +163,11 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg ggml_backend_buffer_t buffer; buffer = apir_decode_ggml_buffer(dec); + if (!buffer || apir_decoder_get_fatal(dec)) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Invalid buffer handle from guest\n", __func__); + return 1; + } + if (!apir_untrack_backend_buffer(buffer)) { GGML_LOG_WARN(GGML_VIRTGPU_BCK "%s: unknown buffer %p\n", __func__, (void *) buffer); return 1; diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp index 64152eef0d..c80e4aabe1 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.cpp @@ -1,6 +1,6 @@ #include "backend-dispatched.h" -#include "backend-virgl-apir.h" +#include "backend-virgl-apir.h" #include "ggml-backend-impl.h" #include "ggml-backend.h" #include "ggml-impl.h" @@ -28,19 +28,24 @@ uint32_t backend_dispatch_initialize(void * ggml_backend_reg_fct_p) { return APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED; } - if (!reg->iface.get_device_count(reg)) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device found\n", __func__); + size_t device_count = reg->iface.get_device_count(reg); + if (!device_count) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: no device found\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } dev = reg->iface.get_device(reg, 0); if (!dev) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed: no device received\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: failed to get device\n", __func__); return APIR_BACKEND_INITIALIZE_NO_DEVICE; } bck = dev->iface.init_backend(dev, NULL); + if (!bck) { + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: backend initialization failed\n", __func__); + return APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED; + } return APIR_BACKEND_INITIALIZE_SUCCESS; } diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h index 481d7f3150..3dc334e4ce 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.gen.h @@ -32,64 +32,6 @@ uint32_t backend_buffer_free_buffer(apir_encoder * enc, apir_decoder * dec, virg /* backend */ uint32_t backend_backend_graph_compute(apir_encoder * enc, apir_decoder * dec, virgl_apir_context * ctx); -static inline const char * backend_dispatch_command_name(ApirBackendCommandType type) { - switch (type) { - /* device */ - case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: - return "backend_device_get_device_count"; - case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: - return "backend_device_get_count"; - case APIR_COMMAND_TYPE_DEVICE_GET_NAME: - return "backend_device_get_name"; - case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: - return "backend_device_get_description"; - case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: - return "backend_device_get_type"; - case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: - return "backend_device_get_memory"; - case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: - return "backend_device_supports_op"; - case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: - return "backend_device_get_buffer_type"; - case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: - return "backend_device_get_props"; - case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: - return "backend_device_buffer_from_ptr"; - /* buffer-type */ - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: - return "backend_buffer_type_get_name"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: - return "backend_buffer_type_get_alignment"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: - return "backend_buffer_type_get_max_size"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: - return "backend_buffer_type_is_host (DEPRECATED)"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: - return "backend_buffer_type_alloc_buffer"; - case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: - return "backend_buffer_type_get_alloc_size"; - /* buffer */ - case APIR_COMMAND_TYPE_BUFFER_GET_BASE: - return "backend_buffer_get_base"; - case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: - return "backend_buffer_set_tensor"; - case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: - return "backend_buffer_get_tensor"; - case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: - return "backend_buffer_cpy_tensor"; - case APIR_COMMAND_TYPE_BUFFER_CLEAR: - return "backend_buffer_clear"; - case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: - return "backend_buffer_free_buffer"; - /* backend */ - case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: - return "backend_backend_graph_compute"; - - default: - return "unknown"; - } -} - extern "C" { static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = { diff --git a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h index 10311631d4..740ee9e3ff 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-dispatched.h +++ b/ggml/src/ggml-virtgpu/backend/backend-dispatched.h @@ -1,5 +1,6 @@ #pragma once +// clang-format off #include #include @@ -10,6 +11,7 @@ #include "shared/apir_backend.h" #include "shared/apir_cs.h" #include "shared/apir_cs_ggml.h" +// clang-format on #define GGML_VIRTGPU_BCK "ggml-virtgpu-backend: " diff --git a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h index 44b347f853..c65a01cdf9 100644 --- a/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h +++ b/ggml/src/ggml-virtgpu/backend/backend-virgl-apir.h @@ -19,7 +19,7 @@ struct virgl_apir_callbacks { }; extern "C" { -ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs); +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs); void apir_backend_deinit(uint32_t virgl_ctx_id); uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, virgl_apir_callbacks * virgl_cbs, diff --git a/ggml/src/ggml-virtgpu/backend/backend.cpp b/ggml/src/ggml-virtgpu/backend/backend.cpp index d93414a078..535a05f3e6 100644 --- a/ggml/src/ggml-virtgpu/backend/backend.cpp +++ b/ggml/src/ggml-virtgpu/backend/backend.cpp @@ -1,6 +1,5 @@ #include "backend-dispatched.h" #include "backend-virgl-apir.h" - #include "shared/api_remoting.h" #include "shared/apir_backend.h" #include "shared/apir_cs.h" @@ -17,10 +16,10 @@ #define GGML_DEFAULT_BACKEND_REG "ggml_backend_init" static void * backend_library_handle = NULL; -static FILE * apir_logfile = NULL; +static FILE * apir_logfile = NULL; static void log_to_file_callback(enum ggml_log_level level, const char * text, void * user_data) { - FILE * logfile = (FILE *)user_data; + FILE * logfile = (FILE *) user_data; fprintf(logfile, "[%d] %s", level, text); fflush(logfile); } @@ -48,9 +47,9 @@ void apir_backend_deinit(uint32_t virgl_ctx_id) { } #define APIR_GGML_LIBRARY_PATH_KEY "ggml.library.path" -#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" +#define APIR_GGML_LIBRARY_REG_KEY "ggml.library.reg" -ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks *virgl_cbs) { +ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct virgl_apir_callbacks * virgl_cbs) { const char * dlsym_error; const char * apir_log_to_file = getenv(APIR_LLAMA_CPP_LOG_TO_FILE_ENV); @@ -63,15 +62,13 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct } } - const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); + const char * library_name = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_PATH_KEY); const char * virgl_library_reg = virgl_cbs->get_config(virgl_ctx_id, APIR_GGML_LIBRARY_REG_KEY); - const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; + const char * library_reg = virgl_library_reg ? virgl_library_reg : GGML_DEFAULT_BACKEND_REG; if (!library_name) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot open the GGML library: env var '%s' not defined\n", - __func__, APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); - + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_PATH_ENV); return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -79,16 +76,14 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct backend_library_handle = dlopen(library_name, RTLD_LAZY); if (!backend_library_handle) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot open the GGML library: %s\n", __func__, dlerror()); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot open the GGML library: %s\n", __func__, dlerror()); return APIR_LOAD_LIBRARY_CANNOT_OPEN; } if (!library_reg) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot register the GGML library: env var '%s' not defined\n", - __func__, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot register the GGML library: env var '%s' not defined\n", __func__, + APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV); return APIR_LOAD_LIBRARY_ENV_VAR_MISSING; } @@ -96,11 +91,9 @@ ApirLoadLibraryReturnCode apir_backend_initialize(uint32_t virgl_ctx_id, struct void * ggml_backend_reg_fct = dlsym(backend_library_handle, library_reg); dlsym_error = dlerror(); if (dlsym_error) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: cannot find the GGML backend registration symbol '%s' (from %s): %s\n", __func__, library_reg, APIR_LLAMA_CPP_GGML_LIBRARY_REG_ENV, dlsym_error); - return APIR_LOAD_LIBRARY_SYMBOL_MISSING; } @@ -132,13 +125,12 @@ uint32_t apir_backend_dispatcher(uint32_t virgl_ctx_id, virgl_apir_context ctx = { .ctx_id = virgl_ctx_id, - .iface = virgl_cbs, + .iface = virgl_cbs, }; if (cmd_type >= APIR_BACKEND_DISPATCH_TABLE_COUNT) { - GGML_LOG_ERROR(GGML_VIRTGPU_BCK - "%s: Received an invalid dispatch index (%d >= %d)\n", - __func__, cmd_type, APIR_BACKEND_DISPATCH_TABLE_COUNT); + GGML_LOG_ERROR(GGML_VIRTGPU_BCK "%s: Received an invalid dispatch index (%d >= %d)\n", __func__, cmd_type, + APIR_BACKEND_DISPATCH_TABLE_COUNT); return APIR_BACKEND_FORWARD_INDEX_INVALID; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h index f19a5d12d1..6bf97e8a3a 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h +++ b/ggml/src/ggml-virtgpu/backend/shared/api_remoting.h @@ -16,28 +16,32 @@ enum ApirCommandType { APIR_COMMAND_TYPE_LOADLIBRARY = 1, APIR_COMMAND_TYPE_FORWARD = 2, - APIR_COMMAND_TYPE_LENGTH = 3, + APIR_COMMAND_TYPE_LENGTH = 3, }; typedef uint64_t ApirCommandFlags; enum ApirLoadLibraryReturnCode { APIR_LOAD_LIBRARY_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component APIR_LOAD_LIBRARY_HYPERCALL_INITIALIZATION_ERROR = 1, APIR_LOAD_LIBRARY_ALREADY_LOADED = 2, APIR_LOAD_LIBRARY_ENV_VAR_MISSING = 3, APIR_LOAD_LIBRARY_CANNOT_OPEN = 4, APIR_LOAD_LIBRARY_SYMBOL_MISSING = 5, - APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, // anything above this is a APIR backend library initialization return code + // any value greater than this is an APIR *backend library* initialization return code + APIR_LOAD_LIBRARY_INIT_BASE_INDEX = 6, }; enum ApirForwardReturnCode { - APIR_FORWARD_SUCCESS = 0, - APIR_FORWARD_NO_DISPATCH_FCT = 1, - APIR_FORWARD_TIMEOUT = 2, - - APIR_FORWARD_BASE_INDEX = 3, // anything above this is a APIR backend library forward return code -} ; + APIR_FORWARD_SUCCESS = 0, + // these error codes are returned by the Virglrenderer APIR component + APIR_FORWARD_NO_DISPATCH_FCT = 1, + APIR_FORWARD_TIMEOUT = 2, + APIR_FORWARD_FAILED_TO_SYNC_STREAMS = 3, + // any value greater than this index an APIR *backend library* forward return code + APIR_FORWARD_BASE_INDEX = 4, +}; __attribute__((unused)) static inline const char * apir_command_name(ApirCommandType type) { switch (type) { @@ -82,6 +86,7 @@ __attribute__((unused)) static const char * apir_forward_error(ApirForwardReturn APIR_FORWARD_ERROR(APIR_FORWARD_SUCCESS); APIR_FORWARD_ERROR(APIR_FORWARD_NO_DISPATCH_FCT); APIR_FORWARD_ERROR(APIR_FORWARD_TIMEOUT); + APIR_FORWARD_ERROR(APIR_FORWARD_FAILED_TO_SYNC_STREAMS); APIR_FORWARD_ERROR(APIR_FORWARD_BASE_INDEX); return "Unknown APIR_COMMAND_TYPE_FORWARD error"; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h index d214b6f2a9..520ac9c729 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.gen.h @@ -34,3 +34,61 @@ typedef enum ApirBackendCommandType { // last command_type index + 1 APIR_BACKEND_DISPATCH_TABLE_COUNT = 23, } ApirBackendCommandType; + +static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) { + switch (type) { + /* device */ + case APIR_COMMAND_TYPE_DEVICE_GET_DEVICE_COUNT: + return "device_get_device_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_COUNT: + return "device_get_count"; + case APIR_COMMAND_TYPE_DEVICE_GET_NAME: + return "device_get_name"; + case APIR_COMMAND_TYPE_DEVICE_GET_DESCRIPTION: + return "device_get_description"; + case APIR_COMMAND_TYPE_DEVICE_GET_TYPE: + return "device_get_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_MEMORY: + return "device_get_memory"; + case APIR_COMMAND_TYPE_DEVICE_SUPPORTS_OP: + return "device_supports_op"; + case APIR_COMMAND_TYPE_DEVICE_GET_BUFFER_TYPE: + return "device_get_buffer_type"; + case APIR_COMMAND_TYPE_DEVICE_GET_PROPS: + return "device_get_props"; + case APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR: + return "device_buffer_from_ptr"; + /* buffer-type */ + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_NAME: + return "buffer_type_get_name"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALIGNMENT: + return "buffer_type_get_alignment"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_MAX_SIZE: + return "buffer_type_get_max_size"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_IS_HOST: + return "buffer_type_is_host"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_ALLOC_BUFFER: + return "buffer_type_alloc_buffer"; + case APIR_COMMAND_TYPE_BUFFER_TYPE_GET_ALLOC_SIZE: + return "buffer_type_get_alloc_size"; + /* buffer */ + case APIR_COMMAND_TYPE_BUFFER_GET_BASE: + return "buffer_get_base"; + case APIR_COMMAND_TYPE_BUFFER_SET_TENSOR: + return "buffer_set_tensor"; + case APIR_COMMAND_TYPE_BUFFER_GET_TENSOR: + return "buffer_get_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CPY_TENSOR: + return "buffer_cpy_tensor"; + case APIR_COMMAND_TYPE_BUFFER_CLEAR: + return "buffer_clear"; + case APIR_COMMAND_TYPE_BUFFER_FREE_BUFFER: + return "buffer_free_buffer"; + /* backend */ + case APIR_COMMAND_TYPE_BACKEND_GRAPH_COMPUTE: + return "backend_graph_compute"; + + default: + return "unknown"; + } +} diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h index f3efa52c72..da1e21b5b2 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_backend.h @@ -14,7 +14,7 @@ #define APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED 6 #define APIR_BACKEND_INITIALIZE_ALREADY_INITED 7 #define APIR_BACKEND_INITIALIZE_NO_DEVICE 8 - +#define APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED 9 // new entries here need to be added to the apir_backend_initialize_error function below @@ -39,6 +39,10 @@ static const char * apir_backend_initialize_error(int code) { APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_BACKEND_SYMBOLS); APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_MISSING_GGML_SYMBOLS); APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_REG_FAILED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_ALREADY_INITED); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_NO_DEVICE); + APIR_BACKEND_INITIALIZE_ERROR(APIR_BACKEND_INITIALIZE_BACKEND_INIT_FAILED); return "Unknown APIR_BACKEND_INITIALIZE error:/"; diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h index 1bc3a5f685..64bf2ec960 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs.h @@ -13,7 +13,6 @@ struct apir_encoder { const char * start; const char * end; bool fatal; - }; struct apir_decoder { @@ -28,8 +27,8 @@ struct apir_decoder { static apir_decoder apir_new_decoder(const char * ptr, size_t size) { apir_decoder dec = { - .cur = ptr, - .end = ptr + size, + .cur = ptr, + .end = ptr + size, .fatal = false, }; @@ -79,10 +78,7 @@ static inline bool apir_decoder_get_fatal(const apir_decoder * dec) { * encode peek */ -static inline bool apir_decoder_peek_internal(apir_decoder * dec, - size_t size, - void * val, - size_t val_size) { +static inline bool apir_decoder_peek_internal(apir_decoder * dec, size_t size, void * val, size_t val_size) { assert(val_size <= size); if (unlikely(size > (size_t) (dec->end - dec->cur))) { @@ -332,8 +328,7 @@ static inline void apir_decode_char_array(apir_decoder * dec, char * val, size_t static inline void * apir_decoder_alloc_array(size_t size, size_t count) { size_t alloc_size; if (unlikely(__builtin_mul_overflow(size, count, &alloc_size))) { - GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", - __func__, size, count); + GGML_LOG_ERROR("%s: overflow in array allocation of %zu * %zu bytes\n", __func__, size, count); return NULL; } @@ -352,20 +347,19 @@ static inline void apir_decode_bool_t(apir_decoder * dec, bool * val) { /* apir_buffer_type_host_handle_t */ -static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, +static inline void apir_encode_apir_buffer_type_host_handle_t(apir_encoder * enc, const apir_buffer_type_host_handle_t * val) { apir_encode(enc, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); } -static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, +static inline void apir_decode_apir_buffer_type_host_handle_t(apir_decoder * dec, apir_buffer_type_host_handle_t * val) { apir_decode(dec, sizeof(apir_buffer_type_host_handle_t), val, sizeof(apir_buffer_type_host_handle_t)); } /* apir_buffer_host_handle_t */ -static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, - const apir_buffer_host_handle_t * val) { +static inline void apir_encode_apir_buffer_host_handle_t(apir_encoder * enc, const apir_buffer_host_handle_t * val) { apir_encode(enc, sizeof(apir_buffer_host_handle_t), val, sizeof(apir_buffer_host_handle_t)); } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h index 289f4b77d7..fabe3e401c 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_ggml.h @@ -1,11 +1,10 @@ -#include "ggml-impl.h" #include "apir_cs.h" #include "apir_cs_rpc.h" +#include "ggml-impl.h" // ggml_buffer_to_apir_host_handle(ggml_backend_buffer_t buffer); -static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, - const apir_buffer_host_handle_t * handle); +static inline void apir_encode_ggml_buffer_host_handle(apir_encoder * enc, const apir_buffer_host_handle_t * handle); static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec); @@ -22,8 +21,7 @@ static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_inplace(apir_decoder return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); } -static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, - uint32_t n_tensors) { +static inline apir_rpc_tensor * apir_decode_apir_rpc_tensor_array_inplace(apir_decoder * dec, uint32_t n_tensors) { size_t apir_rpc_tensor_size = sizeof(apir_rpc_tensor) * n_tensors; return (apir_rpc_tensor *) (uintptr_t) apir_decoder_use_inplace(dec, apir_rpc_tensor_size); @@ -45,9 +43,9 @@ static inline const ggml_tensor * apir_decode_ggml_tensor(apir_decoder * dec) { } ggml_init_params params{ - /*.mem_size =*/ ggml_tensor_overhead(), - /*.mem_buffer =*/ NULL, - /*.no_alloc =*/ true, + /*.mem_size =*/ggml_tensor_overhead(), + /*.mem_buffer =*/NULL, + /*.no_alloc =*/true, }; ggml_context * ctx = ggml_init(params); @@ -105,6 +103,19 @@ static inline ggml_backend_buffer_t apir_decode_ggml_buffer(apir_decoder * dec) apir_decoder_read(dec, buffer_ptr_size, &buffer, buffer_ptr_size); + // SECURITY: Validate buffer handle against tracked buffers to prevent + // guest VM from providing arbitrary host memory addresses + if (buffer) { + extern std::unordered_set backend_buffers; + if (backend_buffers.find(buffer) == backend_buffers.end()) { + GGML_LOG_WARN("ggml-virtgpu-backend: %s: Invalid buffer handle from guest: %p\n", __func__, + (void *) buffer); + // Set fatal flag to prevent further processing with invalid handle + apir_decoder_set_fatal(dec); + return NULL; + } + } + return buffer; } diff --git a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h index f681798952..4cb2f047d1 100644 --- a/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h +++ b/ggml/src/ggml-virtgpu/backend/shared/apir_cs_rpc.h @@ -1,3 +1,6 @@ +#pragma once + +// clang-format off #include "ggml.h" #include "ggml-backend-impl.h" @@ -5,6 +8,7 @@ #include #include #include +// clang-format on // ggml_tensor is serialized into apir_rpc_tensor struct apir_rpc_tensor { diff --git a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp index c493a8e2ae..8fa20ff43b 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-buffer-type.cpp @@ -34,6 +34,7 @@ static ggml_backend_buffer_t ggml_backend_remoting_buffer_type_alloc_buffer(ggml static const char * ggml_backend_remoting_buffer_type_get_name(ggml_backend_buffer_type_t buft) { virtgpu * gpu = BUFT_TO_GPU(buft); + // Return the prefixed name that was built once during initialization return gpu->cached_buffer_type.name; } @@ -53,9 +54,8 @@ static size_t ggml_backend_remoting_buffer_type_get_alloc_size(ggml_backend_buff const ggml_tensor * tensor) { virtgpu * gpu = BUFT_TO_GPU(buft); - if (tensor->buffer == NULL - || !tensor->buffer->context - || !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { + if (tensor->buffer == NULL || !tensor->buffer->context || + !buft->device->iface.supports_buft(buft->device, tensor->buffer->buft)) { return ggml_nbytes(tensor); } diff --git a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp index c7d2881058..ec8156bb86 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-device.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-device.cpp @@ -3,6 +3,7 @@ static const char * ggml_backend_remoting_device_get_name(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); + // Return the prefixed name that was built once during initialization return gpu->cached_device_info.name; } @@ -22,7 +23,7 @@ static enum ggml_backend_dev_type ggml_backend_remoting_device_get_type(ggml_bac static void ggml_backend_remoting_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { virtgpu * gpu = DEV_TO_GPU(dev); - *free = gpu->cached_device_info.memory_free; + *free = gpu->cached_device_info.memory_free; *total = gpu->cached_device_info.memory_total; } @@ -72,7 +73,7 @@ static void ggml_backend_remoting_device_get_props(ggml_backend_dev_t dev, ggml_ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - static std::atomic initialized = false; + static std::atomic initialized = false; static ggml_backend_buffer_type buft; if (!initialized) { @@ -95,7 +96,7 @@ ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_type(ggml_bac static ggml_backend_buffer_type_t ggml_backend_remoting_device_get_buffer_from_ptr_type(ggml_backend_dev_t dev) { virtgpu * gpu = DEV_TO_GPU(dev); - static std::atomic initialized = false; + static std::atomic initialized = false; static ggml_backend_buffer_type buft; if (!initialized) { diff --git a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp index 2d02cfec1d..a4df5956aa 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend-reg.cpp @@ -7,8 +7,8 @@ void ggml_virtgpu_cleanup(virtgpu * gpu); static virtgpu * apir_initialize() { - static virtgpu * gpu = NULL; - static std::atomic initialized = false; + static virtgpu * gpu = NULL; + static std::atomic initialized = false; if (initialized) { // fast track @@ -31,29 +31,53 @@ static virtgpu * apir_initialize() { } // Pre-fetch and cache all device information, it will not change - gpu->cached_device_info.description = apir_device_get_description(gpu); + gpu->cached_device_info.description = apir_device_get_description(gpu); if (!gpu->cached_device_info.description) { GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device description", __func__); } - gpu->cached_device_info.name = apir_device_get_name(gpu); - if (!gpu->cached_device_info.name) { - GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu device name", __func__); - } gpu->cached_device_info.device_count = apir_device_get_count(gpu); gpu->cached_device_info.type = apir_device_get_type(gpu); - apir_device_get_memory(gpu, - &gpu->cached_device_info.memory_free, - &gpu->cached_device_info.memory_total); + { + // Get the remote name and create prefixed version + char * rmt_device_name = apir_device_get_name(gpu); + if (!rmt_device_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu device name", __func__); + } + + size_t device_name_len = strlen(rmt_device_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_device_info.name = (char *) malloc(device_name_len); + if (!gpu->cached_device_info.name) { + free(rmt_device_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed device name", __func__); + } + snprintf(gpu->cached_device_info.name, device_name_len, "[virtgpu] %s", rmt_device_name); + free(rmt_device_name); + } + + apir_device_get_memory(gpu, &gpu->cached_device_info.memory_free, &gpu->cached_device_info.memory_total); apir_buffer_type_host_handle_t buft_host_handle = apir_device_get_buffer_type(gpu); gpu->cached_buffer_type.host_handle = buft_host_handle; - gpu->cached_buffer_type.name = apir_buffer_type_get_name(gpu, buft_host_handle); - if (!gpu->cached_buffer_type.name) { - GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu buffer type name", __func__); + { + // Get the remote name and create prefixed version + char * rmt_name = apir_buffer_type_get_name(gpu, buft_host_handle); + if (!rmt_name) { + GGML_ABORT(GGML_VIRTGPU "%s: failed to get the virtgpu buffer type name", __func__); + } + + size_t prefixed_len = strlen(rmt_name) + 11; // "[virtgpu] " + null terminator + gpu->cached_buffer_type.name = (char *) malloc(prefixed_len); + if (!gpu->cached_buffer_type.name) { + free(rmt_name); + GGML_ABORT(GGML_VIRTGPU "%s: failed to allocate memory for prefixed buffer type name", __func__); + } + snprintf(gpu->cached_buffer_type.name, prefixed_len, "[virtgpu] %s", rmt_name); + free(rmt_name); } - gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); - gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); + + gpu->cached_buffer_type.alignment = apir_buffer_type_get_alignment(gpu, buft_host_handle); + gpu->cached_buffer_type.max_size = apir_buffer_type_get_max_size(gpu, buft_host_handle); initialized = true; } @@ -98,7 +122,7 @@ static void ggml_backend_remoting_reg_init_devices(ggml_backend_reg_t reg) { static std::atomic initialized = false; if (initialized) { - return; // fast track + return; // fast track } { diff --git a/ggml/src/ggml-virtgpu/ggml-backend.cpp b/ggml/src/ggml-virtgpu/ggml-backend.cpp index 5cd6c0c060..a63ee2b9d2 100644 --- a/ggml/src/ggml-virtgpu/ggml-backend.cpp +++ b/ggml/src/ggml-virtgpu/ggml-backend.cpp @@ -1,5 +1,5 @@ -#include "ggml-remoting.h" #include "../../include/ggml-virtgpu.h" +#include "ggml-remoting.h" static const char * ggml_backend_remoting_get_name(ggml_backend_t backend) { UNUSED(backend); diff --git a/ggml/src/ggml-virtgpu/ggml-remoting.h b/ggml/src/ggml-virtgpu/ggml-remoting.h index 0876640867..4f70326bee 100644 --- a/ggml/src/ggml-virtgpu/ggml-remoting.h +++ b/ggml/src/ggml-virtgpu/ggml-remoting.h @@ -9,7 +9,7 @@ #include #define GGML_VIRTGPU_NAME "ggml-virtgpu" -#define GGML_VIRTGPU "ggml-virtgpu: " +#define GGML_VIRTGPU "ggml-virtgpu: " // USE_ALWAYS_TRUE_SUPPORTS_OP: 1 is fast, 0 avoid micro-benchmark crashes diff --git a/ggml/src/ggml-virtgpu/include/apir_hw.h b/ggml/src/ggml-virtgpu/include/apir_hw.h index 33af045ca2..7d6ea2265d 100644 --- a/ggml/src/ggml-virtgpu/include/apir_hw.h +++ b/ggml/src/ggml-virtgpu/include/apir_hw.h @@ -3,7 +3,7 @@ #include struct virgl_renderer_capset_apir { - uint32_t apir_version; - uint32_t supports_blob_resources; - uint32_t reserved[4]; // For future expansion + uint32_t apir_version; + uint32_t supports_blob_resources; + uint32_t reserved[4]; // For future expansion }; diff --git a/ggml/src/ggml-virtgpu/regenerate_remoting.py b/ggml/src/ggml-virtgpu/regenerate_remoting.py index aeb48a4087..dae75fd1c8 100755 --- a/ggml/src/ggml-virtgpu/regenerate_remoting.py +++ b/ggml/src/ggml-virtgpu/regenerate_remoting.py @@ -145,8 +145,31 @@ class RemotingCodebaseGenerator: enum_lines.append(f" APIR_BACKEND_DISPATCH_TABLE_COUNT = {total_count},") enum_lines.append("} ApirBackendCommandType;") + # Generate function name mapping + func_lines = [] + func_lines.append("static inline const char * apir_dispatch_command_name(ApirBackendCommandType type) {") + func_lines.append(" switch (type) {") + + current_group = None + for func in functions: + # Add comment for new group + if func['group_name'] != current_group: + func_lines.append(f" /* {func['group_description']} */") + current_group = func['group_name'] + + # Generate clean function name without backend_ prefix + clean_name = f"{func['group_name']}_{func['function_name']}" + func_lines.append(f" case {func['enum_name']}:") + func_lines.append(f" return \"{clean_name}\";") + + func_lines.append("") + func_lines.append(" default:") + func_lines.append(" return \"unknown\";") + func_lines.append(" }") + func_lines.append("}") + # Full header template - header_content = NL.join(enum_lines) + "\n" + header_content = NL.join(enum_lines) + "\n\n" + NL.join(func_lines) + "\n" return header_content @@ -170,19 +193,6 @@ class RemotingCodebaseGenerator: decl_lines.append(f"{signature} {func['backend_function']}({params});") - # Switch cases - switch_lines = [] - current_group = None - - for func in functions: - if func['group_name'] != current_group: - switch_lines.append(f" /* {func['group_description']} */") - current_group = func['group_name'] - - deprecated = " (DEPRECATED)" if func['deprecated'] else "" - - switch_lines.append(f" case {func['enum_name']}: return \"{func['backend_function']}{deprecated}\";") - # Dispatch table table_lines = [] current_group = None @@ -201,15 +211,6 @@ class RemotingCodebaseGenerator: {NL.join(decl_lines)} -static inline const char *backend_dispatch_command_name(ApirBackendCommandType type) -{{ - switch (type) {{ -{NL.join(switch_lines)} - - default: return "unknown"; - }} -}} - extern "C" {{ static const backend_dispatch_t apir_backend_dispatch_table[APIR_BACKEND_DISPATCH_TABLE_COUNT] = {{ {NL.join(table_lines)} diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp index 07d9a66849..4593690c63 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-backend.cpp @@ -17,8 +17,8 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { size_t cgraph_size = apir_serialize_ggml_cgraph(cgraph, cgraph_data); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (cgraph_size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -26,7 +26,7 @@ ggml_status apir_backend_graph_compute(virtgpu * gpu, ggml_cgraph * cgraph) { GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, cgraph_size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); } diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp index cab74fd170..38f8ec945e 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer-type.cpp @@ -62,7 +62,9 @@ size_t apir_buffer_type_get_max_size(virtgpu * gpu, apir_buffer_type_host_handle return max_size; } -apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size) { +apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + size_t size) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; @@ -84,7 +86,9 @@ apir_buffer_context_t apir_buffer_type_alloc_buffer(virtgpu * gpu, apir_buffer_t return buffer_context; } -size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, const ggml_tensor * op) { +size_t apir_buffer_type_get_alloc_size(virtgpu * gpu, + apir_buffer_type_host_handle_t host_handle, + const ggml_tensor * op) { apir_encoder * encoder; apir_decoder * decoder; ApirForwardReturnCode ret; diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp index 86eee358cf..228284f4a4 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-buffer.cpp @@ -35,8 +35,8 @@ void apir_buffer_set_tensor(virtgpu * gpu, apir_encode_ggml_tensor(encoder, tensor); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -44,7 +44,7 @@ void apir_buffer_set_tensor(virtgpu * gpu, GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); @@ -86,8 +86,8 @@ void apir_buffer_get_tensor(virtgpu * gpu, apir_encode_ggml_tensor(encoder, tensor); virtgpu_shmem temp_shmem; // Local storage for large buffers - virtgpu_shmem * shmem = &temp_shmem; - bool using_shared_shmem = false; + virtgpu_shmem * shmem = &temp_shmem; + bool using_shared_shmem = false; if (size <= gpu->data_shmem.mmap_size) { // Lock mutex before using shared data_shmem buffer @@ -95,7 +95,7 @@ void apir_buffer_get_tensor(virtgpu * gpu, GGML_ABORT(GGML_VIRTGPU "%s: Failed to lock data_shmem mutex", __func__); } using_shared_shmem = true; - shmem = &gpu->data_shmem; + shmem = &gpu->data_shmem; } else if (virtgpu_shmem_create(gpu, size, shmem)) { GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate the guest-host shared buffer", __func__); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp index 4b6b8f527b..9f513c138d 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-device.cpp @@ -26,7 +26,7 @@ char * apir_device_get_name(virtgpu * gpu) { REMOTE_CALL(gpu, encoder, decoder, ret); const size_t string_size = apir_decode_array_size_unchecked(decoder); - char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); + char * string = (char *) apir_decoder_alloc_array(sizeof(char), string_size); if (!string) { GGML_LOG_ERROR(GGML_VIRTGPU "%s: Could not allocate the device name buffer\n", __func__); return NULL; @@ -173,7 +173,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(virtgpu * gpu, size_t size, si REMOTE_CALL_PREPARE(gpu, encoder, APIR_COMMAND_TYPE_DEVICE_BUFFER_FROM_PTR); if (virtgpu_shmem_create(gpu, size, &buffer_context.shmem)) { - GGML_ABORT(GGML_VIRTGPU "Couldn't allocate the guest-host shared buffer"); + GGML_ABORT(GGML_VIRTGPU "%s: Couldn't allocate %ldb of guest-host shared buffer", __func__, size); } apir_encode_virtgpu_shmem_res_id(encoder, buffer_context.shmem.res_id); diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h index f23c75bb96..4d0b6e05c7 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward-impl.h @@ -1,29 +1,36 @@ -#include "virtgpu.h" +#pragma once +// clang-format off +#include "virtgpu.h" #include "ggml-remoting.h" #include "backend/shared/apir_backend.h" #include "backend/shared/apir_cs_ggml.h" - #include "ggml-backend-impl.h" +// clang-format on -#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ - do { \ - int32_t forward_flag = (int32_t) apir_command_type__; \ - encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, forward_flag); \ - if (!encoder_name) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ - } \ +#define REMOTE_CALL_PREPARE(gpu_dev_name, encoder_name, apir_command_type__) \ + int32_t REMOTE_CALL_PREPARE_forward_flag = (int32_t) apir_command_type__; \ + const char * REMOTE_CALL_PREPARE_command_name = apir_dispatch_command_name(apir_command_type__); \ + do { \ + encoder_name = remote_call_prepare(gpu_dev_name, APIR_COMMAND_TYPE_FORWARD, REMOTE_CALL_PREPARE_forward_flag); \ + if (!encoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to prepare the remote call encoder", __func__); \ + } \ } while (0) -#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ - do { \ - ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ - if (!decoder_name) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ - } \ - if (ret_name < APIR_FORWARD_BASE_INDEX) { \ - GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ - apir_forward_error(ret_name), ret_name); \ - } \ - ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ +#define REMOTE_CALL(gpu_dev_name, encoder_name, decoder_name, ret_name) \ + do { \ + ret_name = (ApirForwardReturnCode) remote_call(gpu_dev_name, encoder_name, &decoder_name, 0, NULL); \ + if (!decoder_name) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to kick the remote call", __func__); \ + } \ + if (ret_name < APIR_FORWARD_BASE_INDEX) { \ + GGML_ABORT(GGML_VIRTGPU "%s: failed to forward the API call: %s: code %d", __func__, \ + apir_forward_error(ret_name), ret_name); \ + } \ + ret_name = (ApirForwardReturnCode) (ret_name - APIR_FORWARD_BASE_INDEX); \ + if (ret_name != 0) { \ + GGML_ABORT(GGML_VIRTGPU "backend function '%s' failed (return code: %d)", \ + REMOTE_CALL_PREPARE_command_name, ret_name); \ + } \ } while (0) diff --git a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h index fe4cae2025..44b0ad1ffa 100644 --- a/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h +++ b/ggml/src/ggml-virtgpu/virtgpu-forward.gen.h @@ -20,6 +20,7 @@ apir_buffer_context_t apir_device_buffer_from_ptr(struct virtgpu * gpu, char * apir_buffer_type_get_name(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); size_t apir_buffer_type_get_alignment(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); size_t apir_buffer_type_get_max_size(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle); +/* apir_buffer_type_is_host is deprecated. */ apir_buffer_context_t apir_buffer_type_alloc_buffer(struct virtgpu * gpu, apir_buffer_type_host_handle_t host_handle, size_t size); diff --git a/ggml/src/ggml-virtgpu/virtgpu.cpp b/ggml/src/ggml-virtgpu/virtgpu.cpp index 1e650dc65b..a84a77399d 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.cpp +++ b/ggml/src/ggml-virtgpu/virtgpu.cpp @@ -53,9 +53,9 @@ static int virtgpu_handshake(virtgpu * gpu) { if (!decoder) { GGML_ABORT(GGML_VIRTGPU - "%s: failed to initiate the communication with the virglrenderer library. " - "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", - __func__); + "%s: failed to initiate the communication with the virglrenderer library. " + "Most likely, the wrong virglrenderer library was loaded in the hypervisor.", + __func__); return 1; } @@ -65,8 +65,7 @@ static int virtgpu_handshake(virtgpu * gpu) { uint32_t host_minor; if (ret_magic != APIR_HANDSHAKE_MAGIC) { - GGML_ABORT(GGML_VIRTGPU - "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, + GGML_ABORT(GGML_VIRTGPU "%s: handshake with the virglrenderer failed (code=%d | %s)", __func__, ret_magic, apir_backend_initialize_error(ret_magic)); } else { apir_decode_uint32_t(decoder, &host_major); @@ -140,15 +139,13 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { "Make sure virglrenderer is correctly configured by the hypervisor. (%s) ", __func__, apir_load_library_error(ret)); } else { - GGML_ABORT(GGML_VIRTGPU - "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", __func__, - apir_load_library_error(ret), ret); + GGML_ABORT(GGML_VIRTGPU "%s: virglrenderer could not load the API Remoting backend library. (%s - code %d)", + __func__, apir_load_library_error(ret), ret); } return ret; } - GGML_LOG_INFO(GGML_VIRTGPU - "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); + GGML_LOG_INFO(GGML_VIRTGPU "%s: virglrenderer successfully loaded the API Remoting backend library.\n", __func__); ApirLoadLibraryReturnCode apir_ret = (ApirLoadLibraryReturnCode) (ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX); @@ -158,10 +155,11 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", __func__, apir_load_library_error(apir_ret)); } else if (apir_ret == APIR_LOAD_LIBRARY_SYMBOL_MISSING) { - GGML_ABORT(GGML_VIRTGPU - "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " - "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", - __func__, apir_load_library_error(apir_ret)); + GGML_ABORT( + GGML_VIRTGPU + "%s: the API Remoting backend library couldn't load the GGML backend library, some symbols are missing. " + "Make sure virglrenderer is correctly configured by the hypervisor. (%s)", + __func__, apir_load_library_error(apir_ret)); } else if (apir_ret < APIR_LOAD_LIBRARY_INIT_BASE_INDEX) { GGML_ABORT(GGML_VIRTGPU "%s: the API Remoting backend library couldn't load the GGML backend library: apir code=%d | %s)", @@ -169,8 +167,8 @@ static ApirLoadLibraryReturnCode virtgpu_load_library(virtgpu * gpu) { } else { uint32_t lib_ret = apir_ret - APIR_LOAD_LIBRARY_INIT_BASE_INDEX; GGML_ABORT(GGML_VIRTGPU - "%s: the API Remoting backend library initialize its backend library: apir code=%d)", __func__, - lib_ret); + "%s: the API Remoting backend library failed to initialize its backend library: apir code=%d)", + __func__, lib_ret); } return ret; } @@ -184,55 +182,49 @@ virtgpu * create_virtgpu() { // Initialize mutex to protect shared data_shmem buffer if (mtx_init(&gpu->data_shmem_mutex, mtx_plain) != thrd_success) { delete gpu; - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize data_shmem mutex", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize data_shmem mutex", __func__); return NULL; } if (virtgpu_open(gpu) != APIR_SUCCESS) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to open the virtgpu device\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to open the virtgpu device\n", __func__); return NULL; } if (virtgpu_init_capset(gpu) != APIR_SUCCESS) { if (gpu->use_apir_capset) { GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library supports it.", __func__); + "%s: failed to initialize the virtgpu APIR capset. Make sure that the virglrenderer library " + "supports it.", + __func__); } else { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the virtgpu Venus capset", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the virtgpu Venus capset", __func__); } return NULL; } if (virtgpu_init_context(gpu) != APIR_SUCCESS) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to initialize the GPU context", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to initialize the GPU context", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_REPLY_SIZE, &gpu->reply_shmem)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to create the shared reply memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared reply memory pages", __func__); return NULL; } if (virtgpu_shmem_create(gpu, SHMEM_DATA_SIZE, &gpu->data_shmem)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to create the shared data memory pages", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to create the shared data memory pages", __func__); return NULL; } if (virtgpu_handshake(gpu)) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to handshake with the virglrenderer library", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to handshake with the virglrenderer library", __func__); return NULL; } if (virtgpu_load_library(gpu) != APIR_LOAD_LIBRARY_SUCCESS) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to load the backend library", __func__); + GGML_ABORT(GGML_VIRTGPU "%s: failed to load the backend library", __func__); return NULL; } @@ -243,8 +235,7 @@ static virt_gpu_result_t virtgpu_open(virtgpu * gpu) { drmDevicePtr devs[8]; int count = drmGetDevices2(0, devs, ARRAY_SIZE(devs)); if (count < 0) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to enumerate DRM devices\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to enumerate DRM devices\n", __func__); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -266,19 +257,17 @@ static virt_gpu_result_t virtgpu_open_device(virtgpu * gpu, const drmDevicePtr d int fd = open(node_path, O_RDWR | O_CLOEXEC); if (fd < 0) { - GGML_ABORT(GGML_VIRTGPU - "%s: failed to open %s", __func__, node_path); + GGML_ABORT(GGML_VIRTGPU "%s: failed to open %s", __func__, node_path); return APIR_ERROR_INITIALIZATION_FAILED; } drmVersionPtr version = drmGetVersion(fd); if (!version || strcmp(version->name, "virtio_gpu") || version->version_major != 0) { if (version) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: unknown DRM driver %s version %d\n", __func__, version->name, version->version_major); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: unknown DRM driver %s version %d\n", __func__, version->name, + version->version_major); } else { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to get DRM driver version\n", __func__); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get DRM driver version\n", __func__); } if (version) { @@ -322,9 +311,8 @@ static virt_gpu_result_t virtgpu_init_capset(virtgpu * gpu) { virtgpu_ioctl_get_caps(gpu, gpu->capset.id, gpu->capset.version, &gpu->capset.data, sizeof(gpu->capset.data)); if (ret) { - GGML_LOG_ERROR(GGML_VIRTGPU - "%s: failed to get APIR v%d capset: %s\n", - __func__, gpu->capset.version, strerror(errno)); + GGML_LOG_ERROR(GGML_VIRTGPU "%s: failed to get APIR v%d capset: %s\n", __func__, gpu->capset.version, + strerror(errno)); return APIR_ERROR_INITIALIZATION_FAILED; } @@ -547,13 +535,10 @@ static void log_call_duration(long long call_duration_ns, const char * name) { double call_duration_s = (double) call_duration_ns / 1e9; // 1 second = 1e9 nanoseconds if (call_duration_s > 1) { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %.2fs for the %s host reply...\n", call_duration_s, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fs for the %s host reply...\n", call_duration_s, name); } else if (call_duration_ms > 1) { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %.2fms for the %s host reply...\n", call_duration_ms, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %.2fms for the %s host reply...\n", call_duration_ms, name); } else { - GGML_LOG_INFO(GGML_VIRTGPU - "waited %lldns for the %s host reply...\n", call_duration_ns, name); + GGML_LOG_INFO(GGML_VIRTGPU "waited %lldns for the %s host reply...\n", call_duration_ns, name); } } diff --git a/ggml/src/ggml-virtgpu/virtgpu.h b/ggml/src/ggml-virtgpu/virtgpu.h index 68e0f3a376..f82d8fb50b 100644 --- a/ggml/src/ggml-virtgpu/virtgpu.h +++ b/ggml/src/ggml-virtgpu/virtgpu.h @@ -1,5 +1,6 @@ #pragma once +// clang-format off #include "virtgpu-utils.h" #include "virtgpu-shm.h" #include "virtgpu-apir.h" @@ -23,20 +24,21 @@ #include "apir_hw.h" #include #include "venus_hw.h" +// clang-format on #ifndef VIRTGPU_DRM_CAPSET_APIR // Will be defined include/drm/virtgpu_drm.h when // https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590/diffs // is merged -#define VIRTGPU_DRM_CAPSET_APIR 10 +# define VIRTGPU_DRM_CAPSET_APIR 10 #endif // Mesa/Virlgrenderer Venus internal. Only necessary during the // Venus->APIR transition in Virglrenderer #define VENUS_COMMAND_TYPE_LENGTH 331 -#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 -#define VIRTGPU_DRM_CAPSET_VENUS 4 +#ifndef VIRTGPU_DRM_CAPSET_VENUS // only available with Linux >= v6.16 +# define VIRTGPU_DRM_CAPSET_VENUS 4 #endif typedef uint32_t virgl_renderer_capset; diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index a8840a0773..0fae68628b 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -403,19 +403,20 @@ enum FaCodePath { }; struct vk_fa_pipeline_state { - vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, bool small_cache, FaCodePath path, bool aligned, bool f32acc, uint32_t flags) - : HSK(HSK), HSV(HSV), small_rows(small_rows), small_cache(small_cache), path(path), aligned(aligned), f32acc(f32acc), flags(flags) {} - uint32_t HSK, HSV; - bool small_rows, small_cache; + uint32_t Br, Bc; + uint32_t D_split, row_split; + bool shmem_staging; FaCodePath path; + uint32_t workgroup_size, subgroup_size; bool aligned; bool f32acc; uint32_t flags; + uint32_t limit_occupancy_shmem; bool operator<(const vk_fa_pipeline_state &b) const { - return std::tie(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags) < - std::tie(b.HSK, b.HSV, b.small_rows, b.small_cache, b.path, b.aligned, b.f32acc, b.flags); + return std::tie(HSK, HSV, Br, Bc, D_split, row_split, shmem_staging, path, workgroup_size, subgroup_size, aligned, f32acc, flags, limit_occupancy_shmem) < + std::tie(b.HSK, b.HSV, b.Br, b.Bc, b.D_split, b.row_split, b.shmem_staging, b.path, b.workgroup_size, b.subgroup_size, b.aligned, b.f32acc, b.flags, b.limit_occupancy_shmem); } }; @@ -1656,6 +1657,7 @@ static bool vk_perf_logger_concurrent = false; static bool vk_enable_sync_logger = false; // number of calls between perf logger prints static uint32_t vk_perf_logger_frequency = 1; +static std::string vk_pipeline_stats_filter; class vk_perf_logger { public: @@ -2172,7 +2174,32 @@ static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipelin executableInfo.pipeline = pipeline->pipeline; auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + + bool print_stats = !vk_pipeline_stats_filter.empty() && + pipeline->name.find(vk_pipeline_stats_filter) != std::string::npos; + if (print_stats) { + std::cerr << "ggml_vulkan: pipeline stats for " << pipeline->name << ":" << std::endl; + } + for (auto & s : statistics) { + if (print_stats) { + std::cerr << "ggml_vulkan: " << s.name.data() << ": "; + switch (s.format) { + case vk::PipelineExecutableStatisticFormatKHR::eBool32: + std::cerr << (s.value.b32 ? "true" : "false"); + break; + case vk::PipelineExecutableStatisticFormatKHR::eInt64: + std::cerr << s.value.i64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eUint64: + std::cerr << s.value.u64; + break; + case vk::PipelineExecutableStatisticFormatKHR::eFloat64: + std::cerr << s.value.f64; + break; + } + std::cerr << std::endl; + } // "Register Count" is reported by NVIDIA drivers. if (strcmp(s.name, "Register Count") == 0) { VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); @@ -2755,78 +2782,218 @@ static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events ); } -// number of rows/cols for flash attention shader -static constexpr uint32_t flash_attention_num_small_rows = 32; -static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; +struct vk_fa_tuning_params { + FaCodePath path; + uint32_t workgroup_size; + uint32_t subgroup_size; + uint32_t block_rows; + uint32_t block_cols; + uint32_t d_split; + uint32_t row_split; + bool shmem_staging; + bool disable_subgroups; + uint32_t limit_occupancy_shmem; -static uint32_t get_fa_scalar_num_large_rows(uint32_t hsk, uint32_t hsv, bool small_cache) { - if (hsv >= 192) { - return 2; - } else if ((hsv | hsk) & 8 || small_cache) { - return 4; - } else { - return 8; + void print() const { + std::cerr << "path=" << path << " workgroup_size=" << workgroup_size << " subgroup_size=" << subgroup_size << + " block_rows=" << block_rows << " block_cols=" << block_cols << " d_split=" << d_split << + " row_split=" << row_split << " shmem_staging=" << shmem_staging << " disable_subgroups=" << disable_subgroups << + " limit_occupancy_shmem=" << limit_occupancy_shmem << std::endl; } -} +}; -// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. -// 128 threads split into four subgroups, each subgroup does 1/4 -// of the Bc dimension. -static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; -static constexpr uint32_t scalar_flash_attention_Bc = 64; -static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc); -static uint32_t get_fa_num_small_rows(FaCodePath path) { - if (path == FA_COOPMAT2) { - return flash_attention_num_small_rows; +static vk_fa_tuning_params get_fa_tuning_params_scalar(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(kv_type); + + vk_fa_tuning_params result{}; + result.path = FA_SCALAR; + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + // Disable subgroup use due to performance issues when enforcing subgroup sizes + result.subgroup_size = 32; + result.disable_subgroups = true; + } else if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN) { + result.subgroup_size = n_rows < 4 ? 32 : device->subgroup_size; } else { - return scalar_flash_attention_num_small_rows; + result.subgroup_size = device->subgroup_size; } -} -static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) { - GGML_UNUSED(clamp); + // Row split splits the workgroup so that synchronization only has to happen within subgroups, which avoids barriers + uint32_t row_split_max_hsk = 64; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != AMD_GCN && !device->uma) { + row_split_max_hsk = n_rows <= 8 ? 64 : 128; + } + result.row_split = (n_rows < 4 || hsk <= row_split_max_hsk) ? 1 : 4; - if (path == FA_SCALAR) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, 64}; + if (result.subgroup_size > 32 && (n_rows < 4 || hsk < (result.row_split == 1 ? 128 : 64))) { + result.workgroup_size = result.subgroup_size * 2; + } else { + result.workgroup_size = result.subgroup_size * 4; + } + + const uint32_t D = hsk | hsv; + + const bool reduce_block_rows = D & 8 || n_kv < 1024 || device->vendor_id == VK_VENDOR_ID_INTEL; + + if (n_rows == 1) { + result.block_rows = 1; + result.block_cols = 64; + } else { + // row_split 1 means higher register use per row, so block size has to be adjusted + if (result.row_split == 1) { + result.block_rows = n_rows == 2 ? 2 : ((n_rows <= 4 || reduce_block_rows) ? 4 : 8); } else { - if ((hsv | hsk) & 8) { - // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter - // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 64}; - } else { - return {get_fa_scalar_num_large_rows(hsk, hsv, small_cache), 32}; - } + result.block_rows = n_rows <= 4 ? 4 : ((n_rows <= 8 || reduce_block_rows) ? 8 : 16); } + + result.block_cols = (D & 8) ? 64 : 32; + } + + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + if (!reduce_block_rows && !ggml_vk_flash_attn_scalar_shmem_support(device, result, hsk, hsv, f32acc)) { + result.block_rows /= 2; + } + + // On AMD RDNA, for small head sizes and big batch size the shader uses few registers, so too many subgroups get scheduled + // at once and end up thrashing the cache. Fix this by setting a large (unused) shmem buffer that reduces occupancy. + // This targets an occupancy of 4 subgroups per SIMD. + if (device->vendor_id == VK_VENDOR_ID_AMD && device->properties.limits.maxComputeSharedMemorySize == 65536) { + if (device->architecture != AMD_GCN && n_rows >= 64 && hsk <= 128) { + // 30kb target for hsk > 64, 26kb for <= 64 due to smaller workgroup size + // Values are guessed, tested on RDNA2 + result.limit_occupancy_shmem = (hsk <= 64 ? 26 : 30) * 1024 / 4 / 4; + } else if (device->architecture == AMD_GCN && n_rows <= 8 && hsk >= 256) { + // Same thing for GCN, with an occupancy target of 2 subgroups per SIMD. + // Here low-batch FA with large head size is affected. + // n_rows < 4 switch because workgroup size switches from 128 to 256 there. + result.limit_occupancy_shmem = (n_rows < 4 ? 14 : 26) * 1024 / 4 / 4; + } + } + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params_coopmat1(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_rows); + GGML_UNUSED(n_kv); + GGML_UNUSED(kv_type); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT1; + + const uint32_t D = hsk | hsv; + + const uint32_t coopmat_block_rows = 16; + const uint32_t coopmat_block_cols = 16; + + const uint32_t num_subgroups = 4; + + result.block_rows = coopmat_block_rows; + result.block_cols = coopmat_block_cols * num_subgroups; + result.row_split = num_subgroups; + result.subgroup_size = device->subgroup_size; + result.workgroup_size = num_subgroups * result.subgroup_size; + + const uint32_t D_lsb = D ^ (D & (D-1)); // extract lowest set bit + result.d_split = std::min(std::min(result.subgroup_size, 8u), D_lsb / 4); + + result.shmem_staging = (device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 && hsv < 256) ? 1 : 0; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params_coopmat2(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + GGML_UNUSED(n_kv); + GGML_UNUSED(f32acc); + + vk_fa_tuning_params result{}; + result.path = FA_COOPMAT2; + + const uint32_t D = hsk | hsv; + + const bool small_rows = n_rows < 32; + + if (small_rows) { + result.block_rows = 32; + result.block_cols = 32; + } else if (ggml_is_quantized(kv_type) || hsk >= 256 || hsv >= 256) { + result.block_rows = (hsk >= 512 || hsv >= 512) ? 32 : 64; + result.block_cols = 32; + } else { + result.block_rows = 64; + result.block_cols = 64; + } + + result.subgroup_size = device->subgroup_size; + result.workgroup_size = (small_rows && (D % 32) == 0) ? 256 : 128; + + return result; +} + +static vk_fa_tuning_params get_fa_tuning_params(const vk_device& device, uint32_t hsk, uint32_t hsv, uint32_t n_rows, uint32_t n_kv, ggml_type kv_type, bool f32acc) { + FaCodePath path = device->coopmat2 ? FA_COOPMAT2 : + device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1 && device->architecture == vk_device_architecture::NVIDIA_TURING) { + // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 + path = FA_SCALAR; } if (path == FA_COOPMAT1) { - if (small_rows) { - return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; - } else { - return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + bool shape_ok = (f32acc && device->coopmat_support_16x16x16_f32acc) || + (!f32acc && device->coopmat_support_16x16x16_f16acc); + const vk_fa_tuning_params params = get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + bool shmem_ok = ggml_vk_flash_attn_coopmat_shmem_support(device, params, hsk, hsv, f32acc); + + if (!shape_ok || !shmem_ok) { + path = FA_SCALAR; } } - // small rows, large cols - if (small_rows) { - return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + // scalar is faster than coopmat when N==1 + if (n_rows == 1 && (path == FA_COOPMAT1 || path == FA_COOPMAT2)) { + path = FA_SCALAR; } - // small cols to reduce register count - if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { - if (hsk >= 512 || hsv >= 512) { - return {32, 32}; - } else { - return {64, 32}; - } + switch (path) { + case FA_SCALAR: + return get_fa_tuning_params_scalar(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT1: + return get_fa_tuning_params_coopmat1(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + case FA_COOPMAT2: + return get_fa_tuning_params_coopmat2(device, hsk, hsv, n_rows, n_kv, kv_type, f32acc); + default: + throw std::runtime_error("unsupported FaCodePath"); } - return {64, 64}; } -static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows, bool small_cache) { - return fa_rows_cols(path, hsk, hsv, 0, type, small_rows, small_cache)[1]; +static vk_fa_pipeline_state get_fa_pipeline_state(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool aligned, bool f32acc, + bool use_mask, bool use_mask_opt, bool use_logit_softcap) { + const bool old_amd_windows = device->vendor_id == VK_VENDOR_ID_AMD && device->driver_id == vk::DriverId::eAmdProprietary && + (device->architecture == AMD_GCN || device->architecture == AMD_RDNA1 || device->architecture == AMD_RDNA2); + + uint32_t flags = (use_mask_opt ? 1 : 0) | + (use_mask ? 2 : 0) | + (use_logit_softcap ? 4 : 0) | + (old_amd_windows ? 8 : 0); + + const uint32_t subgroup_size = params.disable_subgroups ? 0 : params.subgroup_size; + + return vk_fa_pipeline_state{hsk, hsv, params.block_rows, params.block_cols, params.d_split, params.row_split, params.shmem_staging, params.path, params.workgroup_size, subgroup_size, aligned, f32acc, flags, params.limit_occupancy_shmem}; +} + +static std::vector get_fa_spec_constants(const vk_fa_pipeline_state& state) { + return {state.workgroup_size, state.Br, state.Bc, state.HSK, state.HSV, !state.aligned, state.D_split, + state.row_split, state.subgroup_size, state.shmem_staging ? 1u : 0u, state.flags, state.limit_occupancy_shmem}; } static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { @@ -3193,76 +3360,43 @@ static void ggml_vk_load_shaders(vk_device& device) { align, disable_robustness, require_full_subgroups, required_subgroup_size); }; - auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache) -> std::array { - return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache)[0], 1, 1}; - }; - - auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows, bool small_cache, uint32_t flags) -> std::vector { - // For large number of rows, 128 invocations seems to work best. - // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we - // can't use 256 for D==80. - // For scalar, use 128 (arbitrary) - // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. - const uint32_t D = (hsk|hsv); - auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows, small_cache); - - uint32_t wg_size; - switch (path) { - case FA_COOPMAT2: - wg_size = ((small_rows && (D % 32) == 0) ? 256 : 128); - break; - case FA_COOPMAT1: - wg_size = (rows_cols[1] / 16) * device->subgroup_size; // enough subgroups for Bc/MatBc - break; - default: - wg_size = scalar_flash_attention_workgroup_size; - break; - } - - // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. - // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. - const uint32_t D_lsb = D ^ (D & (D-1)); - uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); - - // Nvidia prefers shared memory use to load large tiles of K. - // Switch to loading from global memory when it would use too much shared memory. - // AMD prefers loading K directly from global memory - const uint32_t k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256 ? 1 : 0; - - return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split, device->subgroup_size, k_load_shmem, flags}; - }; - #define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ - uint32_t HSK = fa.first.HSK; \ - uint32_t HSV = fa.first.HSV; \ - bool small_rows = fa.first.small_rows; \ - bool small_cache = fa.first.small_cache; \ FaCodePath path = fa.first.path; \ + uint32_t Br = fa.first.Br; \ + uint32_t Bc = fa.first.Bc; \ bool aligned = fa.first.aligned; \ bool f32acc = fa.first.f32acc; \ - uint32_t flags = fa.first.flags; \ + uint32_t fa_sgs = fa.first.subgroup_size; \ + bool fa_ds = fa.first.subgroup_size == 0; \ if (path == FAPATH) { \ if (aligned) { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows,small_cache,flags), fa_align(FAPATH,HSK,HSV,TYPE,small_rows,small_cache), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), Bc, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } else { \ if (f32acc) { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } else { \ - ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows,small_cache,flags), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? device->subgroup_size : 0)); \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 7, sizeof(vk_flash_attn_push_constants), {Br, 1, 1}, get_fa_spec_constants(fa.first), 1, true, (!fa_ds && (FAPATH!=FA_COOPMAT2)), ((!fa_ds && (FAPATH!=FA_COOPMAT2)) ? fa_sgs : 0)); \ } \ } \ } \ } - CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) - CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + if (device->fp16) { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) + } else { + CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, _fp32) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, _fp32) + } #if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) if (device->coopmat1_fa_support) { CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1) @@ -3780,10 +3914,12 @@ static void ggml_vk_load_shaders(vk_device& device) { && !device->coopmat_bf16_support #endif ) { + const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32; + // use scalar tile sizes l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, 2, 2, 1, subgroup_size_8 }; l_wg_denoms = {128, 128, 1 }; m_wg_denoms = { 64, 64, 1 }; @@ -4533,6 +4669,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev); static vk_device ggml_vk_get_device(size_t idx) { VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); @@ -4749,6 +4886,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->shader_core_count = sm_props.shaderSMCount; } else if (amd_shader_core_properties2) { device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL) { + device->shader_core_count = ggml_vk_intel_shader_core_count(device->physical_device); } else { device->shader_core_count = 0; } @@ -4968,11 +5107,7 @@ static vk_device ggml_vk_get_device(size_t idx) { #if defined(VK_KHR_cooperative_matrix) device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; - - // coopmat1 fa shader currently assumes 32 invocations per subgroup - device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && - device->subgroup_size_control && device->subgroup_min_size <= 32 && - device->subgroup_max_size >= 32; + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support; #endif if (coopmat2_support) { @@ -5540,6 +5675,10 @@ static void ggml_vk_instance_init() { vk_perf_logger_concurrent = getenv("GGML_VK_PERF_LOGGER_CONCURRENT") != nullptr; vk_enable_sync_logger = getenv("GGML_VK_SYNC_LOGGER") != nullptr; vk_memory_logger_enabled = getenv("GGML_VK_MEMORY_LOGGER") != nullptr; + const char* GGML_VK_PIPELINE_STATS = getenv("GGML_VK_PIPELINE_STATS"); + if (GGML_VK_PIPELINE_STATS != nullptr) { + vk_pipeline_stats_filter = GGML_VK_PIPELINE_STATS; + } const char* GGML_VK_PERF_LOGGER_FREQUENCY = getenv("GGML_VK_PERF_LOGGER_FREQUENCY"); if (GGML_VK_PERF_LOGGER_FREQUENCY != nullptr) { @@ -8419,21 +8558,27 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx } } -static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool small_cache) { +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { + GGML_UNUSED(f32acc); // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const uint32_t wg_size = scalar_flash_attention_workgroup_size; - const uint32_t Br = get_fa_scalar_num_large_rows(hsk, hsv, small_cache); - const uint32_t Bc = scalar_flash_attention_Bc; + const uint32_t wg_size = params.workgroup_size; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; + const uint32_t float_type_size = device->fp16 ? sizeof(ggml_fp16_t) : sizeof(float); + + // tmpsh is overestimated slightly const uint32_t tmpsh = wg_size * sizeof(float); - const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * float_type_size; - const uint32_t masksh = Bc * Br * sizeof(float); + const uint32_t masksh = Bc * (Br + 1) * float_type_size; - const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + const uint32_t Qf = Br * (hsk / 4 + 1) * 4 * float_type_size; - const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const uint32_t D = std::max(hsk, hsv); + const uint32_t kvsh = params.shmem_staging ? Bc * (D / 4 + 1) * 4 * float_type_size : 4 * float_type_size; + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf + kvsh; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_scalar_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); @@ -8441,18 +8586,17 @@ static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, con return supported; } -static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc, ggml_type kv_type) { +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const vk_fa_tuning_params& params, uint32_t hsk, uint32_t hsv, bool f32acc) { // Needs to be kept up to date on shader changes - GGML_UNUSED(hsv); - const auto rows_cols = fa_rows_cols(FA_COOPMAT1, hsk, hsv, 0, kv_type, false, false); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; + const uint32_t Br = params.block_rows; + const uint32_t Bc = params.block_cols; const uint32_t MatBr = 16, MatBc = 16; const uint32_t row_split = Bc / MatBc; const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + const uint32_t hsv_pad = ROUNDUP_POW2(hsv, 16); const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; @@ -8468,17 +8612,19 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t sfshstride = (hsk <= 128) ? (Br + 8) : Br; const uint32_t sfsh = Bc * sfshstride * acctype; - const bool k_load_shmem = device->vendor_id == VK_VENDOR_ID_NVIDIA && hsk < 256; - const uint32_t kshstride = (k_load_shmem ? hsk_pad : MatBr) / 4 + 2; + const uint32_t kvshstride = (params.shmem_staging ? std::max(hsk_pad, hsv_pad) : MatBr) / 4 + 2; const uint32_t vsh_stride = MatBc / 4 * row_split; - const uint32_t ksh = ((kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)) * f16vec4; + const uint32_t ksh = ((kvshstride >= vsh_stride) ? (Bc * kvshstride) : (Bc * vsh_stride)) * f16vec4; + + const uint32_t osh_stride = params.row_split * MatBr / 4; + const uint32_t pvsh = MatBc * osh_stride * f16vec4; const uint32_t slope = Br * acctype; - const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + pvsh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; - VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); return supported; } @@ -8536,48 +8682,18 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx assert(q->type == GGML_TYPE_F32); assert(k->type == v->type); - FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : - ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; - - if (path == FA_COOPMAT1 && ctx->device->architecture == vk_device_architecture::NVIDIA_TURING) { - // Nvidia compiler bug, see https://github.com/ggml-org/llama.cpp/pull/19075#issuecomment-3820716090 - path = FA_SCALAR; - } - - if (path == FA_COOPMAT1) { - const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || - (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); - - const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32, k->type); - - if (!coopmat_shape_supported || !coopmat_shmem_supported) { - path = FA_SCALAR; - } - } - uint32_t gqa_ratio = 1; uint32_t qk_ratio = neq2 / nek2; uint32_t workgroups_x = (uint32_t)neq1; uint32_t workgroups_y = (uint32_t)neq2; uint32_t workgroups_z = (uint32_t)neq3; - const bool small_cache = nek1 < 1024; + const bool f32acc = !ctx->device->fp16 || dst->op_params[3] == GGML_PREC_F32; // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). - uint32_t max_gqa; - switch (path) { - case FA_SCALAR: - case FA_COOPMAT1: - // We may switch from coopmat1 to scalar, so use the scalar limit for both - max_gqa = get_fa_scalar_num_large_rows(HSK, HSV, small_cache); - break; - case FA_COOPMAT2: - max_gqa = get_fa_num_small_rows(FA_COOPMAT2); - break; - default: - GGML_ASSERT(0); - } + vk_fa_tuning_params tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, 512, KV, k->type, f32acc); + const uint32_t max_gqa = std::min(tuning_params.block_rows, 32u); if (N <= 8 && qk_ratio > 1 && qk_ratio <= max_gqa && qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { @@ -8589,24 +8705,7 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx workgroups_y /= gqa_ratio; } - bool small_rows = N <= get_fa_num_small_rows(path); - - // coopmat1 does not actually support "small rows" (it needs 16 rows). - // So use scalar instead. - if (small_rows && path == FA_COOPMAT1) { - path = FA_SCALAR; - } - - // scalar is faster than coopmat2 when N==1 - if (N == 1 && path == FA_COOPMAT2) { - path = FA_SCALAR; - } - - // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory - if (path == FA_SCALAR && - !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV, small_cache)) { - small_rows = true; - } + tuning_params = get_fa_tuning_params(ctx->device, HSK, HSV, N, KV, k->type, f32acc); const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); @@ -8620,18 +8719,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx v_stride /= 4; } - uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows, small_cache); + const uint32_t alignment = tuning_params.block_cols; bool aligned = (KV % alignment) == 0 && // the "aligned" shader variant will forcibly align strides, for performance (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. - if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + if (((HSK | HSV) % 16) != 0 && tuning_params.path == FA_COOPMAT2) { aligned = false; } - bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; - float scale = 1.0f; float max_bias = 0.0f; float logit_softcap = 0.0f; @@ -8646,12 +8743,8 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx // Only use mask opt when the mask is fairly large. This hasn't been tuned extensively. bool use_mask_opt = mask && nem1 >= 32 && nem0 * nem1 > 32768; - - uint32_t flags = (use_mask_opt ? 1 : 0) | - (mask != nullptr ? 2 : 0) | - (logit_softcap != 0 ? 4 : 0); - - vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, small_cache, path, aligned, f32acc, flags); + vk_fa_pipeline_state fa_pipeline_state = get_fa_pipeline_state(ctx->device, tuning_params, HSK, HSV, aligned, f32acc, + mask != nullptr, use_mask_opt, logit_softcap != 0); vk_pipeline pipeline = nullptr; @@ -8673,22 +8766,35 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx uint32_t split_kv = KV; uint32_t split_k = 1; + // Intel Alchemist prefers more workgroups + const uint32_t shader_core_count_multiplier = (ctx->device->vendor_id == VK_VENDOR_ID_INTEL && ctx->device->architecture != INTEL_XE2) ? 2 : 1; + // Use a placeholder core count if one isn't available. split_k is a big help for perf. - const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count * shader_core_count_multiplier : 16; + + const uint32_t Br = fa_pipeline_state.Br; + const uint32_t Bc = fa_pipeline_state.Bc; + + GGML_ASSERT(Br == pipeline->wg_denoms[0]); + const uint32_t Tr = CEIL_DIV(N, Br); // Try to use split_k when KV is large enough to be worth the overhead. - // Must either be a single batch or be using gqa, we can't mix the two. - if (workgroups_x <= pipeline->wg_denoms[0] && (workgroups_x == 1 || gqa_ratio > 1)) { - // Try to run two workgroups per SM. + if (gqa_ratio > 1 && workgroups_x <= Br) { split_k = shader_core_count * 2 / (workgroups_x * workgroups_y * workgroups_z); - if (split_k > 1) { - // Try to evenly split KV into split_k chunks, but it needs to be a multiple - // of "align", so recompute split_k based on that. - split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); - split_k = CEIL_DIV(KV, split_kv); + } else if (gqa_ratio <= 1) { + uint32_t total_wgs_no_split = Tr * workgroups_y * workgroups_z; + if (total_wgs_no_split < shader_core_count * 2) { + split_k = shader_core_count * 2 / total_wgs_no_split; } } + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + } + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. // For matrices, the order is (inner to outer) [HSV, ne1, k, ne2, ne3]. @@ -8702,10 +8808,6 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx ggml_vk_preallocate_buffers(ctx, subctx); } - auto rows_cols = fa_rows_cols(path, HSK, HSV, !aligned, k->type, small_rows, small_cache); - const uint32_t Br = rows_cols[0]; - const uint32_t Bc = rows_cols[1]; - const uint32_t mask_opt_num_dwords = CEIL_DIV(nem0, 16 * Bc); const uint64_t mask_opt_size = sizeof(uint32_t) * mask_opt_num_dwords * CEIL_DIV(nem1, Br) * nem2 * nem3; @@ -8785,15 +8887,21 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx if (ctx->prealloc_split_k_need_sync) { ggml_vk_sync_buffers(ctx, subctx); } - workgroups_x *= pipeline->wg_denoms[0]; + + // We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. + uint32_t dispatch_x; + if (gqa_ratio > 1) { + workgroups_x *= pipeline->wg_denoms[0]; + dispatch_x = split_k * workgroups_x; + } else { + dispatch_x = Tr * split_k * pipeline->wg_denoms[0]; + } + vk_subbuffer split_k_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {q_buf, k_buf, v_buf, mask_buf, sinks_buf, split_k_buf, mask_opt_buf}, - // We only use split_k when group query attention is enabled, which means - // there's no more than one tile of rows (i.e. workgroups_x would have been - // one). We reuse workgroups_x to mean the number of splits, so we need to - // cancel out the divide by wg_denoms[0]. - pc, { split_k * workgroups_x, workgroups_y, workgroups_z }); + pc, { dispatch_x, workgroups_y, workgroups_z }); ggml_vk_sync_buffers(ctx, subctx); const vk_op_flash_attn_split_k_reduce_push_constants pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, split_k, (sinks != nullptr) }; @@ -13710,12 +13818,11 @@ static bool ggml_vk_can_fuse_rope_set_rows(ggml_backend_vk_context * ctx, const return true; } -// Check whether the tensors overlap in memory but are not equal. -// Fusions can potenitally overwrite src tensors in ways that are not prevented -// by ggml-alloc. If the fusion is entirely elementwise, then it's OK for them -// to overlap if they are exactly equal. -// XXX TODO this check is probably missing from several fusion optimizations. -static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const ggml_tensor * b) { +// Check whether the tensors overlap in memory. +// Fusions can potentially overwrite src tensors in ways that are not prevented +// by ggml-alloc. If the fusion src is being applied in a way that's elementwise +// with the destination, then it's OK for them to overlap if they are exactly equal. +static bool ggml_vk_tensors_overlap(const ggml_tensor * a, const ggml_tensor * b, bool elementwise) { ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)a->buffer->context; vk_buffer a_buf = a_buf_ctx->dev_buffer; ggml_backend_vk_buffer_context * b_buf_ctx = (ggml_backend_vk_buffer_context *)b->buffer->context; @@ -13726,7 +13833,7 @@ static bool ggml_vk_tensors_overlap_but_not_equal(const ggml_tensor * a, const g auto b_base = vk_tensor_offset(b) + b->view_offs; auto b_size = ggml_nbytes(b); - if (a_base == b_base && a_size == b_size) { + if (elementwise && a_base == b_base && a_size == b_size) { return false; } @@ -13764,13 +13871,6 @@ static bool ggml_vk_can_fuse_rms_norm_mul_rope(ggml_backend_vk_context * ctx, co return false; } - // must not overwrite srcs in a way that's not elementwise - ggml_tensor *other_src = mul->src[0] == rms ? mul->src[1] : mul->src[0]; - if (ggml_vk_tensors_overlap_but_not_equal(rms->src[0], rope) || - ggml_vk_tensors_overlap_but_not_equal(other_src, rope)) { - return false; - } - // conditions for pipeline creation if (!(ctx->device->float_controls_rte_fp16 && sizeof(vk_op_rms_norm_mul_rope_push_constants) <= ctx->device->properties.limits.maxPushConstantsSize)) { @@ -13832,6 +13932,18 @@ static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const stru return num_adds; } +static int32_t find_first_set(uint32_t x) { + int32_t ret = 0; + if (!x) { + return -1; + } + while (!(x & 1)) { + x >>= 1; + ret++; + } + return ret; +} + static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; @@ -13930,6 +14042,12 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg total_mul_mat_bytes += bytes; } + // op_srcs_fused_elementwise indicates whether an op's srcs all contribute to + // the fused result in an elementwise-way. This affects whether the memory for + // the src is allowed to overlap the memory for the destination. + // The array is sized to handle the largest fusion (asserted later). + bool op_srcs_fused_elementwise[12]; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; ctx->fused_topk_moe_scale = false; const char *fusion_string {}; @@ -13938,39 +14056,68 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg if (num_adds) { ctx->num_additional_fused_ops = num_adds - 1; fusion_string = "MULTI_ADD"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, true); } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ADD_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT, GGML_OP_ADD })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ADD"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 2; fusion_string = "MUL_MAT_ID_ADD_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_ADD_ID"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_MUL_MAT_ID, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "MUL_MAT_ID_MUL"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 4 }) && ggml_check_edges(cgraph, i, rms_norm_mul_rope_view_set_rows_edges) && ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i + 2)) { ctx->num_additional_fused_ops = 4; fusion_string = "RMS_NORM_MUL_ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; + op_srcs_fused_elementwise[3] = false; + op_srcs_fused_elementwise[4] = false; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL, GGML_OP_ROPE })&& ggml_vk_can_fuse_rms_norm_mul_rope(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "RMS_NORM_MUL_ROPE"; + // rope is approximately elementwise - whole rows are done by a single workgroup and it's row-wise + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = true; + op_srcs_fused_elementwise[2] = true; } else if (ggml_vk_can_fuse(ctx, cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { ctx->num_additional_fused_ops = 1; fusion_string = "RMS_NORM_MUL"; + // rms_norm is not elementwise, but whole rows must be consumed and the scale factor computed before + // they are overwritten, and one workgroup per row. So close enough. + op_srcs_fused_elementwise[0] = true; + op_srcs_fused_elementwise[1] = true; } else if (ggml_can_fuse_subgraph(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, { i + 2 }) && ggml_check_edges(cgraph, i, rope_view_set_rows_edges) && ggml_vk_can_fuse_rope_set_rows(ctx, cgraph, i)) { ctx->num_additional_fused_ops = 2; fusion_string = "ROPE_VIEW_SET_ROWS"; + op_srcs_fused_elementwise[0] = false; + op_srcs_fused_elementwise[1] = false; + op_srcs_fused_elementwise[2] = false; } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax_norm, { i + 3, i + 9 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_norm_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX_NORM)) { @@ -13979,6 +14126,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) && ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) { @@ -13987,6 +14135,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 4; ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS; fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) && ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) { @@ -13995,6 +14144,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 3; ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX; fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_LATE_SOFTMAX)) { @@ -14003,6 +14153,7 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ctx->fused_ops_write_mask |= 1 << 1; ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX; fusion_string = "TOPK_MOE_LATE_SOFTMAX"; + std::fill_n(op_srcs_fused_elementwise, ctx->num_additional_fused_ops + 1, false); } if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) { // Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano. @@ -14010,11 +14161,73 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) { ctx->fused_topk_moe_scale = true; ctx->num_additional_fused_ops++; + op_srcs_fused_elementwise[ctx->num_additional_fused_ops] = false; } } } + GGML_ASSERT(ctx->num_additional_fused_ops < (int)(sizeof(op_srcs_fused_elementwise) / sizeof(op_srcs_fused_elementwise[0]))); ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops; + // Check whether fusion would overwrite src operands while they're still in use. + // If so, disable fusion. + if (ctx->num_additional_fused_ops) { + // There are up to two output nodes - topk_moe has two. + uint32_t bits = ctx->fused_ops_write_mask & ~(1 << ctx->num_additional_fused_ops); + ggml_tensor *output_nodes[2] {}; + output_nodes[0] = cgraph->nodes[i + ctx->num_additional_fused_ops]; + if (bits) { + int output_idx = find_first_set(bits); + GGML_ASSERT(bits == (1u << output_idx)); + output_nodes[1] = cgraph->nodes[i + output_idx]; + } + + bool need_disable = false; + + // topk_moe often overwrites the source, but for a given row all the src values are + // loaded before anything is stored. If there's only one row, this is safe, so treat + // this as a special case. + bool is_topk_moe_single_row = ctx->fused_topk_moe_mode != TOPK_MOE_COUNT && + ggml_nrows(cgraph->nodes[i]->src[0]) == 1; + + if (!is_topk_moe_single_row) { + for (int j = 0; j < 2; ++j) { + ggml_tensor *dst = output_nodes[j]; + if (!dst) { + continue; + } + // Loop over all srcs of all nodes in the fusion. If the src overlaps + // the destination and the src is not an intermediate node that's being + // elided, then disable fusion. + for (int k = 0; k <= ctx->num_additional_fused_ops; ++k) { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + ggml_tensor *src = cgraph->nodes[i + k]->src[s]; + if (!src || src->op == GGML_OP_NONE) { + continue; + } + if (ggml_vk_tensors_overlap(src, dst, op_srcs_fused_elementwise[k])) { + bool found = false; + for (int n = 0; n < k; ++n) { + if (cgraph->nodes[i + n] == src) { + found = true; + break; + } + } + if (!found) { + need_disable = true; + } + } + } + } + } + } + if (need_disable) { + ctx->num_additional_fused_ops = 0; + ctx->fused_ops_write_mask = 1; + ctx->fused_topk_moe_mode = TOPK_MOE_COUNT; + ctx->fused_topk_moe_scale = false; + } + } + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; bool submit = (submitted_nodes >= nodes_per_submit) || @@ -15418,6 +15631,46 @@ static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDevicePrope } } +static uint32_t ggml_vk_intel_shader_core_count(const vk::PhysicalDevice& vkdev) { + VkPhysicalDeviceProperties2 props = vkdev.getProperties2(); + + if (props.properties.vendorID != VK_VENDOR_ID_INTEL) { + return 0; + } + + const uint32_t device_id = props.properties.deviceID; + + switch (device_id) { + case 0x56A6: // A310 + return 6; + case 0x5693: // A370M + case 0x56A5: // A380 + case 0x56B1: // Pro A40/A50 + return 8; + case 0x5697: // A530M + return 12; + case 0x5692: // A550M + case 0x56B3: // Pro A60 + return 16; + case 0x56A2: // A580 + return 24; + case 0x5691: // A730M + case 0x56A1: // A750 + return 28; + case 0x56A0: // A770 + case 0x5690: // A770M + return 32; + case 0xE212: // Pro B50 + return 16; + case 0xE20C: // B570 + return 18; + case 0xE20B: // B580 + return 20; + default: + return 0; + } +} + // checks #ifdef GGML_VULKAN_CHECK_RESULTS @@ -16094,7 +16347,7 @@ static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * ggml_vk_print_graph_origin(tensor, done); } - if (avg_err > 0.5 || std::isnan(avg_err)) { + if (avg_err > 0.01 || std::isnan(avg_err)) { std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; if (src0 != nullptr) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 0735f67854..ec48f5b115 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -3,9 +3,13 @@ #extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_shader_16bit_storage : require -#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_subgroup_extended_types_float16 : require +#endif + #extension GL_KHR_shader_subgroup_shuffle : enable #extension GL_KHR_shader_subgroup_vote : enable @@ -15,8 +19,10 @@ const uint32_t HSK_per_thread = HSK / D_split; const uint32_t HSV_per_thread = HSV / D_split; -const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = WorkGroupSize / D_split / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; +const uint32_t num_subgroups = SubGroupSize == 0 ? 0 : WorkGroupSize / SubGroupSize; layout (binding = 0) readonly buffer Q {float data_q[];}; @@ -27,20 +33,22 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} +// If SubGroupSize is set to 0 then only use shmem reductions +const uint32_t tmpsh_size = (SubGroupSize > 0) ? (row_split == 1 ? num_subgroups * D_split : num_subgroups) : WorkGroupSize; +shared float tmpsh[tmpsh_size]; +shared FLOAT_TYPEV4 tmpshv4[tmpsh_size]; -shared FLOAT_TYPE tmpsh[WorkGroupSize]; -shared vec4 tmpshv4[WorkGroupSize]; +const uint32_t masksh_stride = Br + 1; +shared FLOAT_TYPE masksh[Bc * masksh_stride]; -shared float masksh[Bc][Br]; -shared vec4 Qf[Br][HSK / 4]; +const uint32_t qf_stride = HSK / 4 + 1; +shared FLOAT_TYPEV4 Qf[Br * qf_stride]; + +const uint32_t D = HSK > HSV ? HSK : HSV; +const uint32_t kvsh_stride = D / 4 + 1; +shared FLOAT_TYPEV4 kvsh[SHMEM_STAGING != 0 ? Bc * kvsh_stride : 1]; + +shared vec4 occupancy_limiter[LIMIT_OCCUPANCY_SHMEM > 0 ? LIMIT_OCCUPANCY_SHMEM : 1]; void main() { #ifdef NEEDS_INIT_IQ_SHMEM @@ -50,8 +58,24 @@ void main() { init_indices(); const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t rowgroup_tid = gl_LocalInvocationIndex % threads_per_rowgroup; const uint32_t d_tid = gl_LocalInvocationIndex % D_split; - const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + + if (LIMIT_OCCUPANCY_SHMEM > 0) { + // This just exists to avoid the occupancy_limiter array getting optimized out + occupancy_limiter[tid] = vec4(tid); + + barrier(); + + if (occupancy_limiter[tid] == vec4(99999.0)) { + data_ov4[0] = D_TYPEV4(occupancy_limiter[tid]); + } + } + +#define tile_row(r) (row_tid * rows_per_thread + (r)) uint32_t q_offset = gqa_iq1*p.nb01 + (iq2*p.nb02 + iq3*p.nb03) / 4; @@ -60,37 +84,37 @@ void main() { uint32_t r = (idx + tid) / (HSK / 4); if (r < Br && d < HSK / 4 && i * Br + r < N) { - Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + Qf[r * qf_stride + d] = FLOAT_TYPEV4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); } } barrier(); - vec4 Of[Br][HSV_per_thread / 4]; + FLOAT_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = vec4(0.0); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPEV4(0.0); } } - float Lf[Br], Mf[Br]; + float Lf[rows_per_thread], Mf[rows_per_thread]; // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lf[r] = 0; Mf[r] = NEG_FLT_MAX_OVER_2; } - float slope[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = 1.0; + ACC_TYPE slope[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = ACC_TYPE(1.0); } // ALiBi if (p.max_bias > 0.0f) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + slope[r] = perElemOpComputeSlope(tile_row(r), col_tid, ACC_TYPE(0), iq2); } } @@ -113,75 +137,141 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { + if (MASK_ENABLE) { + if (USE_MASK_OPT && mask_opt_idx != j / 16) { + mask_opt_idx = j / 16; + mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + } + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { + // skip this block + continue; + } + // Only load if the block is not all zeros + if (mask_opt_bits != MASK_OPT_ALL_ZERO) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - if (USE_MASK_OPT && mask_opt_idx != j / 16) { - mask_opt_idx = j / 16; - mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; + float max_mask = NEG_FLT_MAX_OVER_2; + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + FLOAT_TYPE m = FLOAT_TYPE(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + masksh[c * masksh_stride + r] = m; + max_mask = max(max_mask, float(m)); + } else { + masksh[c * masksh_stride + r] = FLOAT_TYPE(0); + } + } + } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } + } } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; - if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { - // skip this block - continue; - } - // Only load if the block is not all zeros - if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { - bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; - float max_mask = NEG_FLT_MAX_OVER_2; - [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { - uint32_t c = (idx + tid) % Bc; - uint32_t r = (idx + tid) / Bc; - if (idx + tid < Bc * Br) { - if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { - float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); - masksh[c][r] = m; - max_mask = max(max_mask, m); + ACC_TYPE Sf[rows_per_thread][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = ACC_TYPE(0.0); + } + } + + if (SHMEM_STAGING != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK / 4 || c < Bc) { + FLOAT_TYPEV4 K_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = K_Tf; + } + } + barrier(); + } + + // More d iterations means Q register caching becomes relevant + // Few iterations means the additional registers needed are worse than the speed-up from caching + if (HSK_per_thread / 4 > 4) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 Q_cache[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Q_cache[r] = Qf[tile_row(r) * qf_stride + d * D_split + d_tid]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; } else { - masksh[c][r] = float(0); +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Q_cache[r], K_Tf)); } } } - // skip the block if the mask is entirely -inf - bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); - barrier(); - if (gl_SubgroupInvocationID == 0) { - tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; - } - barrier(); - [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { - max_mask = max(max_mask, tmpsh[s]); - } - if (max_mask <= NEG_FLT_MAX_OVER_2) { - continue; - } - } - - float Sf[Br][cols_per_thread]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + } else { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = 0.0; - } - } + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } - - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { + FLOAT_TYPEV4 K_Tf; + if (SHMEM_STAGING != 0) { + K_Tf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else - vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); + K_Tf = FLOAT_TYPEV4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Sf[r][c] += ACC_TYPE(dot(Qf[tile_row(r) * qf_stride + d * D_split + d_tid], K_Tf)); + } } } } @@ -189,89 +279,109 @@ void main() { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { // Compute sum across the D_split [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); } } } if (LOGIT_SOFTCAP) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + Sf[r][c] = ACC_TYPE(p.logit_softcap * tanh(Sf[r][c])); } } } if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float mvf = masksh[c * cols_per_iter + col_tid][r]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE mvf = masksh[(c * cols_per_iter + col_tid) * masksh_stride + tile_row(r)]; Sf[r][c] += slope[r]*mvf; } } - barrier(); } - float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - rowmaxf[r] = NEG_FLT_MAX_OVER_2; + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } - rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + rowmaxf = max(rowmaxf, float(Sf[r][c])); } - Moldf[r] = Mf[r]; + float Moldf = Mf[r]; // M = max(rowmax, Mold) // P = e^(S - M) // eM = e^(Mold - M) - Mf[r] = max(rowmaxf[r], Moldf[r]); - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - Pf[r][c] = exp(Sf[r][c] - Mf[r]); - } - eMf[r] = exp(Moldf[r] - Mf[r]); - - // Compute sum across row of P - rowsumf[r] = 0.0; - [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { - if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { - continue; - } - rowsumf[r] += Pf[r][c]; - } - - Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + Lf[r] = eMf[r]*Lf[r]; } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] = eMf[r] * Of[r][d]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = FLOAT_TYPE(eMf[r]) * Of[r][d]; } } + if (SHMEM_STAGING != 0) { + barrier(); + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV / 4); + uint32_t c = (idx + tid) / (HSV / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV / 4 || c < Bc) { + FLOAT_TYPEV4 V_Tf = FLOAT_TYPEV4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + V_Tf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + barrier(); + } + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { continue; } + + FLOAT_TYPE Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = FLOAT_TYPE(exp(float(Sf[r][c]) - Mf[r])); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + FLOAT_TYPEV4 Vf; + if (SHMEM_STAGING != 0) { + Vf = kvsh[(c * cols_per_iter + col_tid) * kvsh_stride + (d * D_split + d_tid)]; + } else { #if BLOCK_SIZE > 1 - uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); - uint ib = coord / BLOCK_SIZE; - uint iqs = (coord % BLOCK_SIZE); - vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else - vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); + Vf = FLOAT_TYPEV4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); #endif - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] += Pf[r][c] * Vf; + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += FLOAT_TYPEV4(Pf[r] * Vf); } } } - - barrier(); } // prevent race on tmpsh @@ -279,58 +389,115 @@ void main() { // reduce across threads - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float rowmaxf, eMf; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = Mf[r]; - tmpsh[tid] = Mf[r]; // Compute max across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + rowmaxf = max(rowmaxf, subgroupShuffleXor(rowmaxf, s)); } + if (row_split == 1) { + // Reduce inside workgroup with shmem + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = rowmaxf; + } + barrier(); + rowmaxf = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + rowmaxf = max(rowmaxf, tmpsh[s * D_split + d_tid]); + } + } + } else { barrier(); + tmpsh[tid] = rowmaxf; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid ^ s]); + } + barrier(); + } + rowmaxf = tmpsh[row_tid * threads_per_rowgroup + d_tid]; } - rowmaxf = tmpsh[d_tid]; - barrier(); float Moldf = Mf[r]; // M = max(rowmax, Mold) // eM = e^(Mold - M) Mf[r] = max(rowmaxf, Moldf); - eMf = exp(Moldf - Mf[r]); + float eMf = exp(Moldf - Mf[r]); Lf[r] = eMf*Lf[r]; - tmpsh[tid] = Lf[r]; - // Compute sum across the row - barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + Lf[r] += subgroupShuffleXor(Lf[r], s); } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpsh[gl_SubgroupID * D_split + d_tid] = Lf[r]; + } + barrier(); + Lf[r] = tmpsh[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Lf[r] += tmpsh[s * D_split + d_tid]; + } + } + } else { barrier(); - } - Lf[r] = tmpsh[d_tid]; - barrier(); - - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - - Of[r][d] = eMf * Of[r][d]; - tmpshv4[tid] = Of[r][d]; - + tmpsh[tid] = Lf[r]; barrier(); - [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { - if (tid < s) { - Of[r][d] += tmpshv4[tid + s]; - tmpshv4[tid] = Of[r][d]; + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid ^ s]; } barrier(); } - Of[r][d] = tmpshv4[d_tid]; - barrier(); + Lf[r] = tmpsh[row_tid * threads_per_rowgroup + d_tid]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] = FLOAT_TYPE(eMf) * Of[r][d]; + + if (SubGroupSize > 0) { + [[unroll]] for (uint s = D_split; s < SubGroupSize; s *= 2) { + if (!OLD_AMD_WINDOWS) { + Of[r][d] += subgroupShuffleXor(Of[r][d], s); + } else { + // Something about f16vec4 subgroupShuffleXor is broken on AMD Windows RDNA2 and below. + // Shuffle full vec4 as workaround. + // See https://github.com/ggml-org/llama.cpp/issues/19881#issuecomment-3958643697 + Of[r][d] += FLOAT_TYPEV4(subgroupShuffleXor(vec4(Of[r][d]), s)); + } + } + if (row_split == 1) { + barrier(); + if (gl_SubgroupInvocationID == d_tid) { + tmpshv4[gl_SubgroupID * D_split + d_tid] = Of[r][d]; + } + barrier(); + Of[r][d] = tmpshv4[d_tid]; + [[unroll]] for (uint32_t s = 1; s < num_subgroups; ++s) { + Of[r][d] += tmpshv4[s * D_split + d_tid]; + } + } + } else { + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + [[unroll]] for (int s = int(threads_per_rowgroup) / 2; s >= D_split; s >>= 1) { + if (rowgroup_tid < s) { + Of[r][d] += tmpshv4[tid ^ s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[row_tid * threads_per_rowgroup + d_tid]; + } } } @@ -338,33 +505,53 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { - perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { + perElemOpStoreCol0(row, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(row, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + data_ov4[o_offset + iq2 * HSV/4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); + } + } + + if (global_row < N && d_tid == 0 && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } } } - return; } if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); float ms = 1.0f; float vs = 1.0f; @@ -373,7 +560,7 @@ void main() { ms = exp(Mf[r] - sink); [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - Of[r][d] *= ms; + Of[r][d] *= FLOAT_TYPE(ms); } } else { vs = exp(sink - Mf[r]); @@ -383,39 +570,37 @@ void main() { } } - float Lfrcp[Br]; - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { Lfrcp[r] = (Lf[r] == 0.0) ? 0.0 : (1.0 / Lf[r]); } [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - Of[r][d] *= Lfrcp[r]; -#if defined(ACC_TYPE_MAX) - Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= FLOAT_TYPE(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); - } + gqaStore(row, d * D_split + d_tid, Of[r][d], o_offset, iq2, N); } } } } else { - [[unroll]] for (uint32_t r = 0; r < Br; ++r) { - if (i * Br + r < N) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + if (i * Br + row < N) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + row) * p.ne1 * HSV) / 4 + d * D_split + d_tid] = D_TYPEV4(Of[r][d]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl index 4142c1e6ea..172d38f034 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -1,20 +1,23 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; -layout (constant_id = 0) const uint32_t WorkGroupSize = 128; -layout (constant_id = 1) const uint32_t Br = 1; -layout (constant_id = 2) const uint32_t Bc = 32; -layout (constant_id = 3) const uint32_t HSK = 32; -layout (constant_id = 4) const uint32_t HSV = 32; -layout (constant_id = 5) const uint32_t Clamp = 0; -layout (constant_id = 6) const uint32_t D_split = 16; -layout (constant_id = 7) const uint32_t SubGroupSize = 32; -layout (constant_id = 8) const uint32_t K_LOAD_SHMEM = 0; -layout (constant_id = 9) const uint32_t Flags = 0; +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; +layout (constant_id = 7) const uint32_t row_split = 1; +layout (constant_id = 8) const uint32_t SubGroupSize = 32; +layout (constant_id = 9) const uint32_t SHMEM_STAGING = 0; +layout (constant_id = 10) const uint32_t Flags = 0; +layout (constant_id = 11) const uint32_t LIMIT_OCCUPANCY_SHMEM = 0; -const bool USE_MASK_OPT = (Flags & 1) != 0; -const bool MASK_ENABLE = (Flags & 2) != 0; -const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool USE_MASK_OPT = (Flags & 1) != 0; +const bool MASK_ENABLE = (Flags & 2) != 0; +const bool LOGIT_SOFTCAP = (Flags & 4) != 0; +const bool OLD_AMD_WINDOWS = (Flags & 8) != 0; // Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths const uint32_t HSK_pad = (HSK + 15) & ~15; @@ -69,6 +72,7 @@ layout (push_constant) uniform parameter { layout (binding = 4) readonly buffer S {float data_s[];}; layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; +layout (binding = 5) writeonly buffer OV4 {D_TYPEV4 data_ov4[];}; layout (binding = 6) readonly buffer MO {uint32_t data_mask_opt[];}; @@ -94,12 +98,12 @@ layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16 #define BLOCK_SIZE 4 #define BLOCK_BYTE_SIZE 16 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { // iqs is currently always zero in the flash attention shaders if (binding_idx == BINDING_IDX_K) { - return k_packed.k_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(k_packed.k_data_packed[a_offset + ib]); } else { - return v_packed.v_data_packed[a_offset + ib]; + return FLOAT_TYPEV4(v_packed.v_data_packed[a_offset + ib]); } } #endif @@ -107,7 +111,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -115,7 +119,7 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } else { uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); @@ -123,24 +127,24 @@ vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { vui_lo >>= shift; vui_hi >>= shift; - return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * (FLOAT_TYPEV4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - FLOAT_TYPE(8.0f)); } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 -vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { +FLOAT_TYPEV4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { if (binding_idx == BINDING_IDX_K) { const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(k_packed.k_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } else { const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return FLOAT_TYPE(v_packed.v_data_packed16[a_offset + ib].d) * FLOAT_TYPEV4(v0.x, v0.y, v1.x, v1.y); } } #endif @@ -189,10 +193,16 @@ void init_indices() KV = p.KV; if (p.k_num > 1) { - i = 0; - // batch and split_k share gl_WorkGroupID.x - gqa_iq1 = gl_WorkGroupID.x / p.k_num; - split_k_index = gl_WorkGroupID.x % p.k_num; + if (p.gqa_ratio > 1) { + i = 0; + // batch and split_k share gl_WorkGroupID.x + gqa_iq1 = gl_WorkGroupID.x / p.k_num; + split_k_index = gl_WorkGroupID.x % p.k_num; + } else { + gqa_iq1 = 0; + split_k_index = gl_WorkGroupID.x % p.k_num; + i = gl_WorkGroupID.x / p.k_num; + } } else if (p.gqa_ratio > 1) { i = 0; gqa_iq1 = gl_WorkGroupID.x; @@ -244,3 +254,11 @@ void init_indices() // Bias applied to softmax to stay in fp16 range. // Based on ggml-cuda issue https://github.com/ggml-org/llama.cpp/issues/18606 const float FATTN_KQ_MAX_OFFSET = 3.0f*0.6931f; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +void gqaStore(const in uint32_t r, const in uint32_t c, const in FLOAT_TYPEV4 elems, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV / 4 + c; + data_ov4[o_offset + offset] = D_TYPEV4(elems); +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index 19630972da..526e8da384 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -19,7 +19,6 @@ const uint32_t MatBr = 16; const uint32_t MatBc = 16; -const uint32_t row_split = Bc / MatBc; const uint32_t rows_per_thread = Br / row_split; const uint32_t cols_per_iter = gl_WorkGroupSize.x / row_split; const uint32_t cols_per_thread = Bc / cols_per_iter; @@ -33,15 +32,6 @@ layout (binding = 2) readonly buffer V {float16_t data_v[];}; layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; layout (binding = 3) readonly buffer M {float16_t data_m[];}; -// Store the output when doing grouped query attention. -// Rows index by Q's dimension 2, and the first N rows are valid. -D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) -{ - uint32_t offset = (iq2 + r) * HSV + c; - data_o[o_offset + offset] = D_TYPE(elem); - return elem; -} - shared float tmpsh[row_split]; const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 @@ -54,10 +44,14 @@ shared f16vec4 Psh[Bc * psh_stride]; const uint32_t sfshstride = (HSK <= 128) ? (Br / 4 + 2) : Br / 4; shared ACC_TYPEV4 sfsh[Bc * sfshstride]; -const uint32_t kshstride = (K_LOAD_SHMEM != 0 ? HSK_pad : MatBr) / 4 + 2; // in units of f16vec4 +const uint32_t D_pad = HSK_pad > HSV_pad ? HSK_pad : HSV_pad; +const uint32_t kvsh_stride = (SHMEM_STAGING != 0 ? D_pad : MatBr) / 4 + 2; // in units of f16vec4 const uint v_cols = MatBc / 4 * row_split; // total cols, 4 vec4s per MatBc * number of subgroups const uint vsh_stride = v_cols; -shared f16vec4 ksh[(kshstride >= vsh_stride) ? (Bc * kshstride) : (Bc * vsh_stride)]; +shared f16vec4 kvsh[(kvsh_stride >= vsh_stride) ? (Bc * kvsh_stride) : (Bc * vsh_stride)]; + +const uint32_t osh_stride = row_split * MatBr / 4; +shared f16vec4 pvsh[MatBc * osh_stride]; shared ACC_TYPE slope[Br]; @@ -84,11 +78,6 @@ void main() { Qf[i + tid] = f16vec4(0); } } - [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { - if (i + tid < Bc * kshstride) { - ksh[i + tid] = f16vec4(0); - } - } barrier(); } @@ -104,10 +93,10 @@ void main() { } barrier(); - ACC_TYPEV4 Of[rows_per_thread][d_per_thread]; + f16vec4 Of[rows_per_thread][d_per_thread]; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t d = 0; d < d_per_thread; ++d) { - Of[r][d] = ACC_TYPEV4(0.0); + Of[r][d] = f16vec4(0.0); } } @@ -153,22 +142,22 @@ void main() { uint32_t mask_opt = 0; uint32_t mask_opt_idx = ~0; + uint32_t mask_opt_bits = 0; + f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[dont_unroll]] for (uint32_t j = start_j; j < end_j; ++j) { - f16vec4 mask_cache[Bc * Br / 4 / WorkGroupSize]; [[unroll]] for (uint32_t idx = 0; idx < mask_cache.length(); ++idx) { mask_cache[idx] = f16vec4(0); } if (MASK_ENABLE) { - if (USE_MASK_OPT && mask_opt_idx != j / 16) { mask_opt_idx = j / 16; mask_opt = data_mask_opt[mo_offset + mask_opt_idx]; } - uint32_t mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; + mask_opt_bits = (mask_opt >> ((j % 16) * 2)) & 0x3; if (mask_opt_bits == MASK_OPT_ALL_NEG_INF) { // skip this block continue; @@ -231,24 +220,24 @@ void main() { } } - if (K_LOAD_SHMEM != 0) { - [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { - uint32_t d = (idx + tid) % (HSK / 4); - uint32_t c = (idx + tid) / (HSK / 4); - if (c < Bc && d < HSK / 4) { + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK_pad / 4); + uint32_t c = (idx + tid) / (HSK_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSK_pad / 4 || c < Bc) { f16vec4 K_Tf = f16vec4(0); - if (!KV_bounds_check || j * Bc + c < KV) { + if ((!KV_bounds_check || j * Bc + c < KV) && (HSK == HSK_pad || d < HSK / 4)) { #if BLOCK_SIZE > 1 uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); #endif } - ksh[c * kshstride + d] = K_Tf; + kvsh[c * kvsh_stride + d] = K_Tf; } } barrier(); @@ -262,7 +251,11 @@ void main() { coopmat QMat; [[unroll]] for (uint32_t d = 0; d < HSK_pad / 16; ++d) { - if (K_LOAD_SHMEM == 0) { + // If SHMEM_STAGING is set, a Bc * HSK_pad size tile of K is loaded to shmem + // If not, f16 K is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If K is not type f16, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 if (KV_bounds_check || d * 16 + 16 > HSK) { #endif @@ -277,13 +270,13 @@ void main() { uint coord = (j * Bc + row) * k_stride * BLOCK_SIZE + d * 16 + col_vec * 4; uint ib = coord / BLOCK_SIZE; uint iqs = (coord % BLOCK_SIZE); - K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); + K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); #else K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + row) * k_stride / 4 + d * 16 / 4 + col_vec]); #endif } - ksh[row * kshstride + col_vec] = K_Tf; + kvsh[row * kvsh_stride + col_vec] = K_Tf; } } barrier(); @@ -295,8 +288,8 @@ void main() { if (KV_bounds_check || d * 16 + 16 > HSK) #endif { - uint coord = (gl_SubgroupID * MatBc) * kshstride; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } #if BLOCK_SIZE == 1 else { @@ -305,8 +298,8 @@ void main() { } #endif } else { - uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; - coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + uint coord = (gl_SubgroupID * MatBc) * kvsh_stride + d * 16 / 4; + coopMatLoad(KMat, kvsh, coord, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); } coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); @@ -329,7 +322,7 @@ void main() { barrier(); } - if (MASK_ENABLE) { + if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { [[unroll]] for (uint32_t idx = 0; idx < Bc * Br / 4; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) / (Br / 4); uint32_t r = (idx + tid) % (Br / 4); @@ -374,7 +367,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] = ACC_TYPE(eMf[r]) * Of[r][d_local]; + Of[r][d_local] = float16_t(eMf[r]) * Of[r][d_local]; } } @@ -397,19 +390,47 @@ void main() { } } + if (SHMEM_STAGING != 0) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSV_pad / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSV_pad / 4); + uint32_t c = (idx + tid) / (HSV_pad / 4); + if (idx + gl_WorkGroupSize.x <= Bc * HSV_pad / 4 || c < Bc) { + f16vec4 V_Tf = f16vec4(0); + if ((!KV_bounds_check || j * Bc + c < KV) && (HSV == HSV_pad || d < HSV / 4)) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * v_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + V_Tf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + V_Tf = f16vec4(data_vv4[v_offset / 4 + (j * Bc + c) * v_stride / 4 + d]); +#endif + } + + kvsh[c * kvsh_stride + d] = V_Tf; + } + } + } + barrier(); + const uint num_hsv_tiles = (HSV + MatBc * row_split - 1) / (MatBc * row_split); // round up // Each subgroup handles HSV/4 columns [[unroll]] for (uint32_t hsv_tile = 0; hsv_tile < num_hsv_tiles; ++hsv_tile) { const uint hsv_offset = (hsv_tile * row_split + gl_SubgroupID) * 16; - SfMat = coopmat(0); + coopmat PVMat = coopmat(0); // Preload V tiles for [Bc, 16 * num subgroups] const uint v_rows = Bc; const uint v_total = v_rows * v_cols; const uint v_loads_per_thread = v_total / gl_WorkGroupSize.x; + // If SHMEM_STAGING is set, a Bc * HSV_pad size tile of V is loaded to shmem. + // If not, f16 V is loaded directly from global memory if aligned, otherwise + // staged through a Bc * MatBr size staging buffer. + // If V is not type f16, then it is always staged for dequantization. + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 // For f16, only preload if not aligned if (KV_bounds_check) { @@ -428,44 +449,52 @@ void main() { if (!KV_bounds_check || (v_row < KV && v_col < HSV)) { #if BLOCK_SIZE > 1 - ksh[row * vsh_stride + col] = f16vec4(dequantize4(ib, iqs, v_offset, BINDING_IDX_V)); + kvsh[row * vsh_stride + col] = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); #else - ksh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; + kvsh[row * vsh_stride + col] = data_vv4[(v_offset + v_row * v_stride + v_col) / 4]; #endif } else { - ksh[row * vsh_stride + col] = f16vec4(0.0f); + kvsh[row * vsh_stride + col] = f16vec4(0.0f); } } + #if BLOCK_SIZE == 1 } #endif - + } barrier(); - [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { - coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + const uint o_offset = gl_SubgroupID * MatBr / 4; + if (hsv_offset < HSV_pad) { + [[unroll]] for (uint32_t bc_chunk = 0; bc_chunk < Bc / MatBc; ++bc_chunk) { + coopMatLoad(KMat, Psh, bc_chunk * MatBc * psh_stride, psh_stride, gl_CooperativeMatrixLayoutColumnMajor); + + if (SHMEM_STAGING == 0) { #if BLOCK_SIZE == 1 - if (!KV_bounds_check) { - // F16 values can be loaded directly from global memory - const uint v_tile_row = j * Bc + bc_chunk * MatBc; - const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; - coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); - } else + if (!KV_bounds_check) { + // F16 values can be loaded directly from global memory + const uint v_tile_row = j * Bc + bc_chunk * MatBc; + const uint v_tile_offset = v_offset / 4 + v_tile_row * v_stride / 4 + hsv_offset / 4; + coopMatLoad(QMat, data_vv4, v_tile_offset, v_stride / 4, gl_CooperativeMatrixLayoutRowMajor); + } else #endif - { - const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); - coopMatLoad(QMat, ksh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + { + const uint v_tile_offset = bc_chunk * MatBr * v_cols + gl_SubgroupID * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, vsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + } else { + const uint v_tile_offset = bc_chunk * MatBc * kvsh_stride + (hsv_tile * row_split + gl_SubgroupID) * (MatBc / 4); + coopMatLoad(QMat, kvsh, v_tile_offset, kvsh_stride, gl_CooperativeMatrixLayoutRowMajor); + } + + PVMat = coopMatMulAdd(KMat, QMat, PVMat); } - SfMat = coopMatMulAdd(KMat, QMat, SfMat); + // Store PVMat to pvsh and load into Of + coopMatStore(PVMat, pvsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); } - // Store SfMat to sfsh and load into Of - const uint osh_stride = row_split * MatBc / 4; - const uint o_offset = gl_SubgroupID * MatBc / 4; - coopMatStore(SfMat, sfsh, o_offset, osh_stride, gl_CooperativeMatrixLayoutRowMajor); - barrier(); const uint hsv_per_tile = row_split * MatBc; @@ -484,7 +513,7 @@ void main() { if (hsv_col >= hsv_base && hsv_col < hsv_base + hsv_per_tile && hsv_col < HSV) { const uint local_hsv = (hsv_col - hsv_base) / 4; - Of[r][d_local] += ACC_TYPEV4(sfsh[row * osh_stride + local_hsv]); + Of[r][d_local] += pvsh[row * osh_stride + local_hsv]; } } } @@ -500,27 +529,48 @@ void main() { // If there is split_k, then the split_k resolve shader does the final // division by L. Store the intermediate O value and per-row m and L values. if (p.k_num > 1) { - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)) / 4; - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { - const uint d = d0 + col_tid; - if (d >= HSV/4) break; - const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + const uint d_local = d0 / threads_per_rowgroup; + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } - } - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - if (tile_row(r) < N) { - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); - perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + const uint row = tile_row(r); + const uint global_row = i * Br + row; + + if (global_row < N) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)) / 4; + + [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { + const uint d = d0 + col_tid; + if (d >= HSV/4) break; + data_ov4[o_offset + iq2 * HSV/4 + d] = D_TYPEV4(Of[r][d/threads_per_rowgroup]); + } + } + + if (global_row < N && col_tid == 0) { + uint32_t lm_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_offset + iq2] = D_TYPE(Lf[r]); + data_o[lm_offset + p.ne1 + iq2] = D_TYPE(Mf[r]); + } } } @@ -539,7 +589,7 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; - Of[r][d_local] *= ACC_TYPE(ms); + Of[r][d_local] *= float16_t(ms); } } else { vs = exp(sink - Mf[r]); @@ -557,14 +607,14 @@ void main() { [[unroll]] for (uint32_t d0 = 0; d0 < HSV / 4; d0 += threads_per_rowgroup) { const uint d_local = d0 / threads_per_rowgroup; [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { - Of[r][d_local] *= ACC_TYPE(Lfrcp[r]); -#if defined(ACC_TYPE_MAX) - Of[r][d_local] = clamp(Of[r][d_local], -ACC_TYPE_MAX, ACC_TYPE_MAX); + Of[r][d_local] *= float16_t(Lfrcp[r]); +#if defined(FLOAT_TYPE_MAX) + Of[r][d_local] = clamp(Of[r][d_local], -FLOAT_TYPE_MAX, FLOAT_TYPE_MAX); #endif } } - uint32_t o_offset = gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV; + uint32_t o_offset = (gqa_iq1*p.ne1*HSV + iq3*p.ne2*p.ne1*HSV) / 4; if (p.gqa_ratio > 1) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { @@ -573,9 +623,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - perElemOpGqaStore(tile_row(r), 4 * d + comp, float(Of[r][d_local][comp]), o_offset, iq2, N); - } + gqaStore(tile_row(r), d, Of[r][d_local], o_offset, iq2, N); } } } @@ -586,9 +634,7 @@ void main() { const uint d = d0 + col_tid; if (d >= HSV / 4) break; const uint d_local = d0 / threads_per_rowgroup; - [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { - data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4 * d + comp] = D_TYPE(Of[r][d_local][comp]); - } + data_ov4[o_offset + (iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV) / 4 + d] = D_TYPEV4(Of[r][d_local]); } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 853f17fa16..0ea181342c 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -72,6 +72,28 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +// Store O values for non-GQA split_k. Rows are tokens, not heads. +D_TYPE perElemOpNonGqaSplitKStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t unused, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c < HSV) { + uint32_t o_off = HSV * p.ne1 + * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[o_off + iq2 * HSV + c] = D_TYPE(elem); + } + return elem; +} + +// Store L/M values for non-GQA split_k. +ACC_TYPE perElemOpNonGqaSplitKStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t lm_base, const in uint32_t iq2, const in uint32_t N) { + uint32_t global_row = i * Br + r; + if (global_row < N && c == 0) { + uint32_t lm_off = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + + p.ne1 * 2 * (split_k_index + p.k_num * (global_row + p.ne2 * iq3)); + data_o[lm_off + lm_base + iq2] = D_TYPE(elem); + } + return elem; +} + void main() { #ifdef NEEDS_INIT_IQ_SHMEM init_iq_shmem(gl_WorkGroupSize); @@ -290,13 +312,19 @@ void main() { if (p.k_num > 1) { coopmat O_D = coopmat(O); - // note: O and Q have swapped coord 1,2. - uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + if (p.gqa_ratio > 1) { + // note: O and Q have swapped coord 1,2. + uint32_t o_offset = HSV * p.ne1 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); - o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); - coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); - coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + o_offset = HSV * p.ne1 * p.k_num * p.ne2 * p.ne3 + p.ne1 * 2 * (split_k_index + p.k_num * (gqa_iq1 + p.ne2 * iq3)); + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + } else { + coopMatPerElementNV(O_D, O_D, perElemOpNonGqaSplitKStore, 0u, iq2, N); + coopMatPerElementNV(L, L, perElemOpNonGqaSplitKStoreCol0, 0u, iq2, N); + coopMatPerElementNV(M, M, perElemOpNonGqaSplitKStoreCol0, p.ne1, iq2, N); + } return; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp index 717d124e01..497a18ff8a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -167,7 +167,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl index 743004ff8a..26c5c12a49 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_id_funcs.glsl @@ -43,7 +43,9 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { uint id = ids[iter++]; uvec4 ballot = subgroupBallot(in_range && id == expert_idx); - ballots_sh[gl_SubgroupID] = ballot; + if (gl_SubgroupInvocationID == 0) { + ballots_sh[gl_SubgroupID] = ballot; + } barrier(); uint subgroup_base = 0; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 42ebc21e2a..85455988c5 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -595,8 +595,6 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c } void process_shaders() { - std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; - // matmul for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { // No coopmats @@ -622,49 +620,63 @@ void process_shaders() { } } - // flash attention - for (const auto& f16acc : {false, true}) { - std::map fa_base_dict = base_dict; - fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; - fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; - if (f16acc) { - fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; + for (const bool& fp16 : {false, true}) { + std::map base_dict; + if (fp16) { + base_dict = {{"FLOAT_TYPE", "float16_t"}, {"FLOAT_TYPEV4", "f16vec4"}, {"FLOAT16", "1"}, {"FLOAT_TYPE_MAX", "float16_t(65504.0)"}}; + } else { + base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPEV4", "vec4"}}; } - for (const auto& tname : type_names) { - if (tname == "bf16") continue; - -#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); - } else { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + // flash attention + for (const bool& f16acc : {false, true}) { + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = fp16 && f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = fp16 && f16acc ? "f16vec4" : "vec4"; + if (fp16 && f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "float16_t(65504.0)"; } + + for (const auto& tname : type_names) { + if (tname == "bf16") continue; + + if (fp16) { +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, true, f16acc); + } #endif #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); - } + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), fp16, true, false, f16acc); + } #endif - if (tname == "f16") { - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); - } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { - std::string data_a_key = "DATA_A_" + to_uppercase(tname); - string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", - merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}}), fp16, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0" || tname == "f32") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"D_TYPEV4", "vec4"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), fp16, false, false, f16acc); + } } } } + std::map base_dict = {{"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}}; + for (const auto& tname : type_names) { // mul mat vec std::string data_a_key = "DATA_A_" + to_uppercase(tname); diff --git a/ggml/src/ggml-zendnn/CMakeLists.txt b/ggml/src/ggml-zendnn/CMakeLists.txt index f5cf6eedd3..9bdb4e836d 100644 --- a/ggml/src/ggml-zendnn/CMakeLists.txt +++ b/ggml/src/ggml-zendnn/CMakeLists.txt @@ -1,12 +1,19 @@ ggml_add_backend_library(ggml-zendnn ggml-zendnn.cpp) -# Get ZenDNN path if (NOT DEFINED ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "") set(ZENDNN_ROOT "$ENV{ZENDNN_ROOT}") endif() -# Check if path is still empty or OFF +if (BUILD_SHARED_LIBS) + set(ZENDNN_SHARED_LIB ON) + set(ZENDNN_ARCHIVE_LIB OFF) +else() + set(ZENDNN_SHARED_LIB OFF) + set(ZENDNN_ARCHIVE_LIB ON) +endif() + +# Download and build ZenDNN if not provided if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") message(STATUS "ZENDNN_ROOT not set. Automatically downloading and building ZenDNN...") message(STATUS "This will take several minutes on first build...") @@ -21,7 +28,7 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") ExternalProject_Add( zendnn GIT_REPOSITORY https://github.com/amd/ZenDNN.git - GIT_TAG 21ce8f7879c86bf3637f707fae6f29e0951db5fe + GIT_TAG a18adf8c605fb5f5e52cefd7eda08a7b18febbaf # ZenDNN-2026-WW08 PREFIX ${ZENDNN_PREFIX} SOURCE_DIR ${ZENDNN_SOURCE_DIR} BINARY_DIR ${ZENDNN_BUILD_DIR} @@ -32,7 +39,9 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") -DZENDNNL_BUILD_DOXYGEN=OFF -DZENDNNL_BUILD_GTEST=OFF -DZENDNNL_BUILD_BENCHDNN=OFF - # Enable ALL matmul algorithm backends + -DZENDNNL_DEPENDS_FBGEMM=OFF + -DZENDNNL_LIB_BUILD_ARCHIVE=${ZENDNN_ARCHIVE_LIB} + -DZENDNNL_LIB_BUILD_SHARED=${ZENDNN_SHARED_LIB} -DZENDNNL_DEPENDS_AOCLDLP=ON -DZENDNNL_DEPENDS_ONEDNN=ON -DZENDNNL_DEPENDS_LIBXSMM=ON @@ -45,47 +54,37 @@ if (NOT ZENDNN_ROOT OR ZENDNN_ROOT STREQUAL "" OR ZENDNN_ROOT STREQUAL "OFF") LOG_INSTALL ON ) - # Add dependency so ZenDNN builds before our library add_dependencies(ggml-zendnn zendnn) - - # Set ZENDNN_ROOT to the installation directory set(ZENDNN_ROOT ${ZENDNN_INSTALL_DIR}) - message(STATUS "ZenDNN will be built to: ${ZENDNN_ROOT}") else() message(STATUS "Using custom ZenDNN installation at: ${ZENDNN_ROOT}") endif() -# ZenDNN headers + libs target_include_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/include - ${ZENDNN_ROOT}/deps/aocldlp/include - ${ZENDNN_ROOT}/deps/aoclutils/include ${ZENDNN_ROOT}/deps/json/include - ${ZENDNN_ROOT}/deps/libxsmm/include + ${ZENDNN_ROOT}/deps/aoclutils/include + ${ZENDNN_ROOT}/deps/aocldlp/include ${ZENDNN_ROOT}/deps/onednn/include -) + ${ZENDNN_ROOT}/deps/libxsmm/include) -target_link_directories(ggml-zendnn PRIVATE - ${ZENDNN_ROOT}/zendnnl/lib - ${ZENDNN_ROOT}/deps/aocldlp/lib - ${ZENDNN_ROOT}/deps/aoclutils/lib - ${ZENDNN_ROOT}/deps/libxsmm/lib - ${ZENDNN_ROOT}/deps/onednn/lib -) +if (ZENDNN_SHARED_LIB) + target_link_directories(ggml-zendnn PRIVATE ${ZENDNN_ROOT}/zendnnl/lib) + target_link_libraries(ggml-zendnn PRIVATE zendnnl) +elseif (ZENDNN_ARCHIVE_LIB) + target_link_libraries(ggml-zendnn PRIVATE + ${ZENDNN_ROOT}/zendnnl/lib/libzendnnl_archive.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libaoclutils.a + ${ZENDNN_ROOT}/deps/aoclutils/${CMAKE_INSTALL_LIBDIR}/libau_cpuid.a + ${ZENDNN_ROOT}/deps/aocldlp/lib/libaocl-dlp.a + ${ZENDNN_ROOT}/deps/onednn/${CMAKE_INSTALL_LIBDIR}/libdnnl.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmm.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmext.a + ${ZENDNN_ROOT}/deps/libxsmm/lib/libxsmmnoblas.a) +endif() -target_link_libraries(ggml-zendnn PRIVATE - zendnnl_archive # ZenDNN main - aocl-dlp # AOCL libraries - aoclutils - au_cpuid - dnnl # OneDNN - xsmm # libxsmm small matrix math - xsmmext - xsmmnoblas - m - pthread -) +target_link_libraries(ggml-zendnn PRIVATE m pthread) if (GGML_OPENMP) target_link_libraries(ggml-zendnn PRIVATE OpenMP::OpenMP_CXX) diff --git a/ggml/src/ggml-zendnn/ggml-zendnn.cpp b/ggml/src/ggml-zendnn/ggml-zendnn.cpp index 551c15bb4a..c876030400 100644 --- a/ggml/src/ggml-zendnn/ggml-zendnn.cpp +++ b/ggml/src/ggml-zendnn/ggml-zendnn.cpp @@ -41,13 +41,13 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int const TA * A, int64_t lda, const TB * B, int64_t ldb, TC * C, int64_t ldc) { - zendnnl::lowoha::lowoha_params params; + zendnnl::lowoha::matmul::matmul_params params; params.dtypes.src = ggml_to_zendnn_type(); params.dtypes.wei = ggml_to_zendnn_type(); params.dtypes.dst = ggml_to_zendnn_type(); params.num_threads = ctx->n_threads; - zendnnl::lowoha::status_t status = zendnnl::lowoha::matmul_direct( + zendnnl::error_handling::status_t status = zendnnl::lowoha::matmul::matmul_direct( 'r', false, true, // row-major, don't transpose B, transpose A (because it's column-major) n, // M: rows of B and C m, // N: cols of A^T and C @@ -63,7 +63,7 @@ static bool ggml_zendnn_matmul(ggml_backend_zendnn_context * ctx, int64_t m, int params // params ); - if (status != zendnnl::lowoha::status_t::success) { + if (status != zendnnl::error_handling::status_t::success) { GGML_LOG_ERROR("%s, ZenDNN matmul failed: status=%d\n", __func__, static_cast(status)); return false; } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index ed819eaa4c..e9529fbb66 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -899,7 +899,8 @@ static const struct ggml_type_traits type_traits[GGML_TYPE_COUNT] = { }; const struct ggml_type_traits * ggml_get_type_traits(enum ggml_type type) { - GGML_ASSERT(type < GGML_TYPE_COUNT); + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return &type_traits[type]; } @@ -1265,27 +1266,33 @@ size_t ggml_nbytes_pad(const struct ggml_tensor * tensor) { } int64_t ggml_blck_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].blck_size; } size_t ggml_type_size(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].type_size; } size_t ggml_row_size(enum ggml_type type, int64_t ne) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); assert(ne % ggml_blck_size(type) == 0); return ggml_type_size(type)*ne/ggml_blck_size(type); } -double ggml_type_sizef(enum ggml_type type) { - return ((double)(type_traits[type].type_size))/type_traits[type].blck_size; -} - const char * ggml_type_name(enum ggml_type type) { - return type < GGML_TYPE_COUNT ? type_traits[type].type_name : "NONE"; + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); + return type_traits[type].type_name; } bool ggml_is_quantized(enum ggml_type type) { + assert(type >= 0); + assert(type < GGML_TYPE_COUNT); return type_traits[type].is_quantized; } @@ -1629,11 +1636,23 @@ static struct ggml_object * ggml_new_object(struct ggml_context * ctx, enum ggml const size_t cur_end = cur_offs + cur_size; // align to GGML_MEM_ALIGN + GGML_ASSERT(size <= SIZE_MAX - (GGML_MEM_ALIGN - 1)); size_t size_needed = GGML_PAD(size, GGML_MEM_ALIGN); char * const mem_buffer = ctx->mem_buffer; struct ggml_object * const obj_new = (struct ggml_object *)(mem_buffer + cur_end); + // integer overflow checks + if (cur_end > SIZE_MAX - size_needed) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu)\n", __func__, cur_end, size_needed); + return NULL; + } + if (cur_end + size_needed > SIZE_MAX - GGML_OBJECT_SIZE) { + GGML_LOG_WARN("%s: overflow detected in cur_end (%zu) + size_needed (%zu) + GGML_OBJECT_SIZE (%zu)\n", __func__, + cur_end, size_needed, (size_t) GGML_OBJECT_SIZE); + return NULL; + } + if (cur_end + size_needed + GGML_OBJECT_SIZE > ctx->mem_size) { GGML_LOG_WARN("%s: not enough space in the context's memory pool (needed %zu, available %zu)\n", __func__, cur_end + size_needed + GGML_OBJECT_SIZE, ctx->mem_size); @@ -1702,6 +1721,8 @@ static struct ggml_tensor * ggml_new_tensor_impl( obj_alloc_size = data_size; } + GGML_ASSERT(GGML_TENSOR_SIZE <= SIZE_MAX - obj_alloc_size); + struct ggml_object * const obj_new = ggml_new_object(ctx, GGML_OBJECT_TYPE_TENSOR, GGML_TENSOR_SIZE + obj_alloc_size); GGML_ASSERT(obj_new); diff --git a/ggml/src/gguf.cpp b/ggml/src/gguf.cpp index ed0d7f2cae..cbeedf6c4b 100644 --- a/ggml/src/gguf.cpp +++ b/ggml/src/gguf.cpp @@ -15,6 +15,17 @@ #include #include +#define GGUF_MAX_STRING_LENGTH (1024*1024*1024) +#define GGUF_MAX_ARRAY_ELEMENTS (1024*1024*1024) + +#ifdef _WIN32 +# define gguf_ftell _ftelli64 +# define gguf_fseek _fseeki64 +#else +# define gguf_ftell ftello +# define gguf_fseek fseeko +#endif + template struct type_to_gguf_type; @@ -217,17 +228,64 @@ struct gguf_context { }; struct gguf_reader { - FILE * file; + gguf_reader(FILE * file) : file(file) { + // read the remaining bytes once and update on each read + nbytes_remain = file_remain(file); + } - gguf_reader(FILE * file) : file(file) {} + // helper for remaining bytes in a file + static uint64_t file_remain(FILE * file) { + const int64_t cur = gguf_ftell(file); + if (cur < 0) { + return 0; + } + if (gguf_fseek(file, 0, SEEK_END) != 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + const int64_t end = gguf_ftell(file); + if (end < 0) { + gguf_fseek(file, cur, SEEK_SET); + + return 0; + } + gguf_fseek(file, cur, SEEK_SET); + return static_cast(end - cur); + } template bool read(T & dst) const { - return fread(&dst, 1, sizeof(dst), file) == sizeof(dst); + const size_t size = sizeof(dst); + if (nbytes_remain < size) { + return false; + } + const size_t nread = fread(&dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } template bool read(std::vector & dst, const size_t n) const { + if (n > GGUF_MAX_ARRAY_ELEMENTS) { + return false; + } + if constexpr (std::is_same::value) { + // strings are prefixed with their length, so we need to account for that + if (n > SIZE_MAX / sizeof(uint64_t)) { + return false; + } + if (nbytes_remain < n * sizeof(uint64_t)) { + return false; + } + } else { + if (n > SIZE_MAX / sizeof(T)) { + return false; + } + if (nbytes_remain < n * sizeof(T)) { + return false; + } + } dst.resize(n); for (size_t i = 0; i < dst.size(); ++i) { if constexpr (std::is_same::value) { @@ -277,13 +335,33 @@ struct gguf_reader { if (!read(size)) { return false; } - dst.resize(size); - return fread(dst.data(), 1, dst.length(), file) == dst.length(); + if (size > GGUF_MAX_STRING_LENGTH) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds maximum %" PRIu64 "\n", __func__, size, (uint64_t) GGUF_MAX_STRING_LENGTH); + return false; + } + if (size > nbytes_remain) { + GGML_LOG_ERROR("%s: string length %" PRIu64 " exceeds remaining file size %" PRIu64 " bytes\n", __func__, size, nbytes_remain); + return false; + } + dst.resize(static_cast(size)); + const size_t nread = fread(dst.data(), 1, size, file); + nbytes_remain -= nread; + return nread == size; } bool read(void * dst, const size_t size) const { - return fread(dst, 1, size, file) == size; + if (size > nbytes_remain) { + return false; + } + const size_t nread = fread(dst, 1, size, file); + nbytes_remain -= nread; + return nread == size; } + +private: + FILE * file; + + mutable uint64_t nbytes_remain; }; struct gguf_context * gguf_init_empty(void) { @@ -568,8 +646,8 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // check that tensor type is within defined range if (info.t.type < 0 || info.t.type >= GGML_TYPE_COUNT) { - GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d (%s)\n", - __func__, info.t.name, info.t.type, ggml_type_name(info.t.type)); + GGML_LOG_ERROR("%s: tensor '%s' has invalid ggml type %d. should be in [0, %d)\n", + __func__, info.t.name, info.t.type, GGML_TYPE_COUNT); ok = false; break; } @@ -618,14 +696,14 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par GGML_ASSERT(int64_t(ctx->info.size()) == n_tensors); // we require the data section to be aligned, so take into account any padding - if (fseek(file, GGML_PAD(ftell(file), ctx->alignment), SEEK_SET) != 0) { + if (gguf_fseek(file, GGML_PAD(gguf_ftell(file), ctx->alignment), SEEK_SET) != 0) { GGML_LOG_ERROR("%s: failed to seek to beginning of data section\n", __func__); gguf_free(ctx); return nullptr; } // store the current file offset - this is where the data section starts - ctx->offset = ftell(file); + ctx->offset = gguf_ftell(file); // compute the total size of the data section, taking into account the alignment { @@ -657,10 +735,34 @@ struct gguf_context * gguf_init_from_file_impl(FILE * file, struct gguf_init_par // the ggml_tensor structs to the appropriate locations in the binary blob // compute the exact size needed for the new ggml_context - const size_t mem_size = - params.no_alloc ? - (n_tensors )*ggml_tensor_overhead() : - (n_tensors + 1)*ggml_tensor_overhead() + ctx->size; + size_t mem_size = 0; + if (params.no_alloc) { + if (n_tensors != 0 && SIZE_MAX / n_tensors < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = n_tensors * ggml_tensor_overhead(); + + mem_size = overhead; + } else { + if ((n_tensors + 1) != 0 && SIZE_MAX / (n_tensors + 1) < ggml_tensor_overhead()) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + const size_t overhead = (n_tensors + 1) * ggml_tensor_overhead(); + + if (SIZE_MAX - overhead < ctx->size) { + GGML_LOG_ERROR("%s: memory size overflow while allocating ggml context\n", __func__); + gguf_free(ctx); + return nullptr; + } + + mem_size = overhead + ctx->size; + } struct ggml_init_params pdata = { /*mem_size =*/ mem_size, diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index 689acdc65d..839c6e787f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -379,6 +379,7 @@ class MODEL_ARCH(IntEnum): NEO_BERT = auto() JINA_BERT_V2 = auto() JINA_BERT_V3 = auto() + EUROBERT = auto() BLOOM = auto() STABLELM = auto() QWEN = auto() @@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum): FFN_GATE_EXP = auto() FFN_DOWN_EXP = auto() FFN_UP_EXP = auto() + FFN_GATE_UP_EXP = auto() FFN_GATE_SHEXP = auto() FFN_DOWN_SHEXP = auto() FFN_UP_SHEXP = auto() @@ -820,6 +822,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.NEO_BERT: "neo-bert", MODEL_ARCH.JINA_BERT_V2: "jina-bert-v2", MODEL_ARCH.JINA_BERT_V3: "jina-bert-v3", + MODEL_ARCH.EUROBERT: "eurobert", MODEL_ARCH.BLOOM: "bloom", MODEL_ARCH.STABLELM: "stablelm", MODEL_ARCH.QWEN: "qwen", @@ -978,6 +981,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.FFN_GATE_EXP: "blk.{bid}.ffn_gate_exps", MODEL_TENSOR.FFN_DOWN_EXP: "blk.{bid}.ffn_down_exps", MODEL_TENSOR.FFN_UP_EXP: "blk.{bid}.ffn_up_exps", + MODEL_TENSOR.FFN_GATE_UP_EXP: "blk.{bid}.ffn_gate_up_exps", MODEL_TENSOR.FFN_EXP_PROBS_B: "blk.{bid}.exp_probs_b", MODEL_TENSOR.LAYER_OUT_NORM: "blk.{bid}.layer_output_norm", MODEL_TENSOR.PER_LAYER_TOKEN_EMBD: "per_layer_token_embd", # gemma3n @@ -1587,6 +1591,19 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_UP, MODEL_TENSOR.LAYER_OUT_NORM, ], + MODEL_ARCH.EUROBERT: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_DOWN, + ], MODEL_ARCH.MPT: [ MODEL_TENSOR.TOKEN_EMBD, MODEL_TENSOR.OUTPUT_NORM, @@ -1805,6 +1822,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_GATE_UP_EXP, MODEL_TENSOR.SSM_A, MODEL_TENSOR.SSM_CONV1D, MODEL_TENSOR.SSM_DT, @@ -1894,6 +1912,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_GATE_UP_EXP, MODEL_TENSOR.SSM_A, MODEL_TENSOR.SSM_CONV1D, MODEL_TENSOR.SSM_DT, @@ -2595,6 +2614,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_GATE_EXP, MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_UP_EXP, MODEL_TENSOR.FFN_GATE_SHEXP, MODEL_TENSOR.FFN_DOWN_SHEXP, MODEL_TENSOR.FFN_UP_SHEXP, diff --git a/gguf-py/gguf/gguf_reader.py b/gguf-py/gguf/gguf_reader.py index d87e8f7232..0a1b85f506 100644 --- a/gguf-py/gguf/gguf_reader.py +++ b/gguf-py/gguf/gguf_reader.py @@ -175,6 +175,9 @@ class GGUFReader: if new_align.types != [GGUFValueType.UINT32]: raise ValueError('Bad type for general.alignment field') self.alignment = new_align.parts[-1][0] + # Ensure alignment is a non-zero power of two + if self.alignment == 0 or (self.alignment & (self.alignment - 1)) != 0: + raise ValueError('Invalid alignment: must be a non-zero power of two') padding = offs % self.alignment if padding != 0: offs += self.alignment - padding @@ -202,11 +205,11 @@ class GGUFReader: def _push_field(self, field: ReaderField, skip_sum: bool = False) -> int: if field.name in self.fields: - # TODO: add option to generate error on duplicate keys - # raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') + # TODO: add option to make this a warning and accept duplicate keys like below + raise KeyError(f'Duplicate {field.name} already in list at offset {field.offset}') - logger.warning(f'Duplicate key {field.name} at offset {field.offset}') - self.fields[field.name + '_{}'.format(field.offset)] = field + # logger.warning(f'Duplicate key {field.name} at offset {field.offset}') + # self.fields[field.name + '_{}'.format(field.offset)] = field else: self.fields[field.name] = field return 0 if skip_sum else sum(int(part.nbytes) for part in field.parts) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 4245d18bc4..9ee3ac9e8f 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -501,6 +501,8 @@ class GGUFWriter: self.add_uint32(Keys.General.QUANTIZATION_VERSION, quantization_version) def add_custom_alignment(self, alignment: int) -> None: + if alignment <= 0 or (alignment & (alignment - 1)) != 0: + raise ValueError('Invalid alignment: must be a non-zero power of two') self.data_alignment = alignment self.add_uint32(Keys.General.ALIGNMENT, alignment) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index fc468d0774..e575610900 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -567,6 +567,10 @@ class TensorNameMap: "model.layers.{bid}.mlp.chunk_experts.gate_proj", # grovemoe ), + MODEL_TENSOR.FFN_GATE_UP_EXP: ( + "model.layers.{bid}.mlp.experts.gate_up_proj", + ), + # Feed-forward down MODEL_TENSOR.FFN_DOWN: ( "gpt_neox.layers.{bid}.mlp.dense_4h_to_h", # gptneox diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 48693ae3e3..5fb2755f1a 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "gguf" -version = "0.17.1" +version = "0.18.0" description = "Read and write ML models in GGUF for GGML" authors = ["GGML "] packages = [ diff --git a/scripts/compare-logprobs.py b/scripts/compare-logprobs.py index 63861dd9a4..ac10085b78 100644 --- a/scripts/compare-logprobs.py +++ b/scripts/compare-logprobs.py @@ -25,16 +25,12 @@ Example usage: """ -def generate_input_prompt(length: int) -> list[str]: - CORPUS = """ - You are an advanced AI assistant capable of using tools to gather information, perform calculations, or execute tasks. Always think step by step before responding. If a user's query requires external data, computation, or actions beyond your internal knowledge, use the appropriate tools via function calls. - - ### Tool Call Format: - When you need to use a tool, output the call in this exact XML format. Include the opening and closing tags. Do not escape arguments; they will be parsed as plain text. - - You can make multiple calls in one go by placing them one after another. - """ - words = [w.strip() for w in CORPUS.strip().split(" ")] +def get_remote_corpus(url: str, length: int) -> list[str]: + response = requests.get(url) + response.raise_for_status() + corpus = response.text + words = [w.strip() for w in corpus.strip().split(" ")] + words = [w for w in words if "<" not in w] # make sure nothing looks like special tokens words = [w for w in words if len(w) > 0] # filter out empty strings while len(words) < length: words += words @@ -226,9 +222,9 @@ def parse_args() -> argparse.Namespace: ) parser_dump.add_argument( "--file", - type=Path, - default=None, - help="File containing prompt to use instead of the default", + type=str, + default="https://raw.githubusercontent.com/ggml-org/llama.cpp/eaba92c3dcc980ebe753348855d4a5d75c069997/tools/server/README.md", + help="File containing prompt to use instead of the default (can also be an URL)", ) parser_dump.add_argument( "--pattern", @@ -259,17 +255,19 @@ def main(): if args.verb == "dump": pattern = parse_pattern(args.pattern) - input_length = sum(n for _, n in pattern) - input_words = generate_input_prompt(input_length) - if args.file is not None: - with args.file.open("r") as f: + required_words = sum(n for _, n in pattern) + if args.file.startswith("http"): + input_words = get_remote_corpus(args.file, required_words) + logger.info(f"Fetched {len(input_words)} words from remote {args.file}") + else: + with open(args.file, "r") as f: input_words = f.read().strip().split(" ") - if input_length < sum(n for _, n in pattern): + input_words = [w for w in input_words if len(w) > 0] # filter out empty strings + if len(input_words) < required_words: raise ValueError( - f"Input file has only {input_length} words, but pattern requires at least {input_length} words." + f"Input file has only {len(input_words)} words, but pattern requires at least {required_words} words." ) - input_length = len(input_words) - logger.info(f"Using {input_length} words") + logger.info(f"Using {len(input_words)} words") dump_logits(args.endpoint, args.output, input_words, pattern, args.api_key) elif args.verb == "compare": compare_logits(args.input1, args.input2, args.output) diff --git a/scripts/snapdragon/adb/run-cli.sh b/scripts/snapdragon/adb/run-cli.sh index d19d4e920e..dfc051b28b 100755 --- a/scripts/snapdragon/adb/run-cli.sh +++ b/scripts/snapdragon/adb/run-cli.sh @@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \ $verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \ ./$branch/bin/llama-cli --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --batch-size 128 -fa on \ - -ngl 99 --device $device $cli_opts $@ \ + --ctx-size 8192 --ubatch-size 256 -fa on \ + -ngl 99 --device $device $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-completion.sh b/scripts/snapdragon/adb/run-completion.sh index da9df110a0..d53b588739 100755 --- a/scripts/snapdragon/adb/run-completion.sh +++ b/scripts/snapdragon/adb/run-completion.sh @@ -54,6 +54,6 @@ adb $adbserial $adbhost shell " \ $verbose $experimental $sched $opmask $profile $nhvx $ndev $hb \ ./$branch/bin/llama-completion --no-mmap -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --batch-size 128 -fa on \ - -ngl 99 -no-cnv --device $device $cli_opts $@ \ + --ctx-size 8192 --ubatch-size 256 -fa on \ + -ngl 99 -no-cnv --device $device $cli_opts $@ \ " diff --git a/scripts/snapdragon/adb/run-mtmd.sh b/scripts/snapdragon/adb/run-mtmd.sh index fc018e7269..41d7cd44f8 100755 --- a/scripts/snapdragon/adb/run-mtmd.sh +++ b/scripts/snapdragon/adb/run-mtmd.sh @@ -58,11 +58,11 @@ adb $adbserial $adbhost shell " \ cd $basedir; ulimit -c unlimited; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \ - ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \ - --mmproj $basedir/../gguf/$mmproj \ - --image $basedir/../gguf/$image \ - --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on \ - -ngl 99 --device $device -v $cli_opts $@ \ + $verbose $experimental $sched $opmask $profile $nhvx $ndev $mtmd_backend \ + ./$branch/bin/llama-mtmd-cli --no-mmap -m $basedir/../gguf/$model \ + --mmproj $basedir/../gguf/$mmproj \ + --image $basedir/../gguf/$image \ + --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ + --ctx-size 8192 --ubatch-size 256 -fa on \ + -ngl 99 --device $device -v $cli_opts $@ \ " diff --git a/scripts/snapdragon/windows/run-cli.ps1 b/scripts/snapdragon/windows/run-cli.ps1 index b13161aa63..40c7acc430 100644 --- a/scripts/snapdragon/windows/run-cli.ps1 +++ b/scripts/snapdragon/windows/run-cli.ps1 @@ -49,5 +49,5 @@ $env:ADSP_LIBRARY_PATH="$basedir\lib" & "$basedir\bin\llama-completion.exe" ` --no-mmap -no-cnv -m $basedir\..\..\gguf\$model ` --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 ` - --ctx-size 8192 --batch-size 128 -ctk q8_0 -ctv q8_0 -fa on ` + --ctx-size 8192 --ubatch-size 128 -fa on ` -ngl 99 --device $device $cli_opts diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index fe1286d009..a26cb26c9b 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -5,7 +5,7 @@ import os import sys import subprocess -HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6" +HTTPLIB_VERSION = "refs/tags/v0.35.0" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", @@ -14,8 +14,8 @@ vendor = { "https://raw.githubusercontent.com/nothings/stb/refs/heads/master/stb_image.h": "vendor/stb/stb_image.h", # not using latest tag to avoid this issue: https://github.com/ggml-org/llama.cpp/pull/17179#discussion_r2515877926 - # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", + # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.24/miniaudio.h": "vendor/miniaudio/miniaudio.h", + "https://github.com/mackron/miniaudio/raw/13d161bc8d856ad61ae46b798bbeffc0f49808e8/miniaudio.h": "vendor/miniaudio/miniaudio.h", f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h", f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py", diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 2a661a1fe8..283823fa9c 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -62,6 +62,7 @@ add_library(llama models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/eurobert.cpp models/exaone-moe.cpp models/exaone.cpp models/exaone4.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 39ebb9db02..47e8d5278a 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -26,6 +26,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_NEO_BERT, "neo-bert" }, { LLM_ARCH_JINA_BERT_V2, "jina-bert-v2" }, { LLM_ARCH_JINA_BERT_V3, "jina-bert-v3" }, + { LLM_ARCH_EUROBERT, "eurobert" }, { LLM_ARCH_BLOOM, "bloom" }, { LLM_ARCH_STABLELM, "stablelm" }, { LLM_ARCH_QWEN, "qwen" }, @@ -348,6 +349,7 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_FFN_DOWN_EXP, "blk.%d.ffn_down.%d" }, { LLM_TENSOR_FFN_UP_EXP, "blk.%d.ffn_up.%d" }, { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_GATE_UP_EXPS, "blk.%d.ffn_gate_up_exps" }, { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, { LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" }, @@ -819,6 +821,20 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, }; + case LLM_ARCH_EUROBERT: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + }; case LLM_ARCH_MODERN_BERT: return { LLM_TENSOR_TOKEN_EMBD, @@ -989,6 +1005,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1046,6 +1063,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -1586,6 +1604,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_DOWN_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_GATE_INP_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_DOWN_SHEXP, @@ -2670,6 +2689,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_FFN_DOWN_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, + {LLM_TENSOR_FFN_GATE_UP_EXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_DOWN_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_GATE_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, {LLM_TENSOR_FFN_UP_CHEXPS, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT_ID}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index 11daa14133..6d1b1df31c 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -30,6 +30,7 @@ enum llm_arch { LLM_ARCH_NEO_BERT, LLM_ARCH_JINA_BERT_V2, LLM_ARCH_JINA_BERT_V3, + LLM_ARCH_EUROBERT, LLM_ARCH_BLOOM, LLM_ARCH_STABLELM, LLM_ARCH_QWEN, @@ -372,6 +373,7 @@ enum llm_tensor { LLM_TENSOR_FFN_DOWN_EXPS, // merged experts LLM_TENSOR_FFN_GATE_EXPS, LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_UP_EXPS, LLM_TENSOR_FFN_DOWN_SHEXP, LLM_TENSOR_FFN_GATE_SHEXP, LLM_TENSOR_FFN_UP_SHEXP, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 7cd0bfc0d2..98d055d34e 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -2440,64 +2440,6 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { // TODO: add more model-specific info which should prevent loading the session file if not identical } - // write output ids - { - LLAMA_LOG_DEBUG("%s: - writing output ids\n", __func__); - - const auto n_outputs = this->n_outputs; - const auto & output_ids = this->output_ids; - - std::vector w_output_pos; - - w_output_pos.resize(n_outputs); - - // build a more compact representation of the output ids - for (size_t i = 0; i < n_batch(); ++i) { - // map an output id to a position in the batch - int64_t pos = output_ids[i]; - if (pos >= 0) { - GGML_ASSERT(pos < n_outputs); - w_output_pos[pos] = i; - } - } - - io.write(&n_outputs, sizeof(n_outputs)); - - if (n_outputs) { - io.write(w_output_pos.data(), n_outputs * sizeof(int32_t)); - } - } - - // [TAG_CONTEXT_STATE_LOGITS] - // write logits - { - LLAMA_LOG_DEBUG("%s: - writing logits\n", __func__); - - const uint64_t logits_size = std::min((uint64_t) this->logits.size, (uint64_t) n_outputs * model.vocab.n_tokens()); - - io.write(&logits_size, sizeof(logits_size)); - - if (logits_size) { - io.write(logits.data, logits_size * sizeof(float)); - } - } - - // write embeddings - { - LLAMA_LOG_DEBUG("%s: - writing embeddings\n", __func__); - - const uint64_t embd_size = std::min((uint64_t) this->embd.size, (uint64_t) n_outputs * model.hparams.n_embd); - - io.write(&embd_size, sizeof(embd_size)); - - if (embd_size) { - io.write(embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2523,70 +2465,6 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { // TODO: add more info which needs to be identical but which is not verified otherwise } - // read output ids - { - LLAMA_LOG_DEBUG("%s: - reading output ids\n", __func__); - - auto n_outputs = this->n_outputs; - io.read_to(&n_outputs, sizeof(n_outputs)); - - if (n_outputs > output_reserve(n_outputs)) { - throw std::runtime_error("could not reserve outputs"); - } - - std::vector output_pos; - - if (n_outputs) { - output_pos.resize(n_outputs); - io.read_to(output_pos.data(), n_outputs * sizeof(int32_t)); - - for (int32_t i = 0; i < (int32_t) output_pos.size(); ++i) { - int32_t id = output_pos[i]; - if ((uint32_t) id >= n_batch()) { - throw std::runtime_error(format("invalid output id, %d does not fit in batch size of %u", id, n_batch())); - } - this->output_ids[id] = i; - } - - this->n_outputs = n_outputs; - } - } - - // read logits - { - LLAMA_LOG_DEBUG("%s: - reading logits\n", __func__); - - uint64_t logits_size; - io.read_to(&logits_size, sizeof(logits_size)); - - if (this->logits.size < logits_size) { - throw std::runtime_error("logits buffer too small"); - } - - if (logits_size) { - io.read_to(this->logits.data, logits_size * sizeof(float)); - } - } - - // read embeddings - { - LLAMA_LOG_DEBUG("%s: - reading embeddings\n", __func__); - - uint64_t embd_size; - io.read_to(&embd_size, sizeof(embd_size)); - - if (this->embd.size < embd_size) { - throw std::runtime_error("embeddings buffer too small"); - } - - if (embd_size) { - io.read_to(this->embd.data, embd_size * sizeof(float)); - } - } - - // TODO: handle sampling buffers and samplers state ? - // https://github.com/ggml-org/llama.cpp/pull/17004 - if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index dc58c0826a..23a86ea290 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -1165,7 +1165,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps) const { return build_moe_ffn( cur, gate_inp, /* gate_inp_b */ nullptr, @@ -1181,7 +1182,8 @@ ggml_tensor * llm_graph_context::build_moe_ffn( w_scale, gating_op, il, - probs_in + probs_in, + gate_up_exps ); } @@ -1204,7 +1206,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in) const { + ggml_tensor * probs_in, + ggml_tensor * gate_up_exps, + ggml_tensor * gate_up_exps_b) const { const int64_t n_embd = cur->ne[0]; const int64_t n_tokens = cur->ne[1]; const bool weight_before_ffn = arch == LLM_ARCH_LLAMA4; // for llama4, we apply the sigmoid-ed weights before the FFN @@ -1343,26 +1347,48 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_weighted", il); } - ggml_tensor * up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] - cb(up, "ffn_moe_up", il); - - if (up_exps_b) { - up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); - cb(up, "ffn_moe_up_biased", il); - } - + ggml_tensor * up = nullptr; ggml_tensor * experts = nullptr; - if (gate_exps) { - cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + + if (gate_up_exps) { + // merged gate_up path: one mul_mat_id, then split into gate and up views + ggml_tensor * gate_up = build_lora_mm_id(gate_up_exps, cur, selected_experts); // [n_ff*2, n_expert_used, n_tokens] + cb(gate_up, "ffn_moe_gate_up", il); + + if (gate_up_exps_b) { + gate_up = ggml_add_id(ctx0, gate_up, gate_up_exps_b, selected_experts); + cb(gate_up, "ffn_moe_gate_up_biased", il); + } + + const int64_t n_ff = gate_up->ne[0] / 2; + cur = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], 0); cb(cur, "ffn_moe_gate", il); + up = ggml_view_3d(ctx0, gate_up, n_ff, gate_up->ne[1], gate_up->ne[2], gate_up->nb[1], gate_up->nb[2], n_ff * gate_up->nb[0]); + cb(up, "ffn_moe_up", il); } else { - cur = up; + // separate gate and up path + up = build_lora_mm_id(up_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(up, "ffn_moe_up", il); + + if (up_exps_b) { + up = ggml_add_id(ctx0, up, up_exps_b, selected_experts); + cb(up, "ffn_moe_up_biased", il); + } + + if (gate_exps) { + cur = build_lora_mm_id(gate_exps, cur, selected_experts); // [n_ff, n_expert_used, n_tokens] + cb(cur, "ffn_moe_gate", il); + } else { + cur = up; + } + + if (gate_exps_b) { + cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); + cb(cur, "ffn_moe_gate_biased", il); + } } - if (gate_exps_b) { - cur = ggml_add_id(ctx0, cur, gate_exps_b, selected_experts); - cb(cur, "ffn_moe_gate_biased", il); - } + const bool has_gate = gate_exps || gate_up_exps; switch (type_op) { case LLM_FFN_SILU: @@ -1385,7 +1411,9 @@ ggml_tensor * llm_graph_context::build_moe_ffn( break; } } + } + if (has_gate) { cur = ggml_swiglu_split(ctx0, cur, up); cb(cur, "ffn_moe_swiglu", il); } else { @@ -1393,7 +1421,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_silu", il); } break; case LLM_FFN_GELU: - if (gate_exps) { + if (has_gate) { cur = ggml_geglu_split(ctx0, cur, up); cb(cur, "ffn_moe_geglu", il); } else { @@ -1409,7 +1437,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_swiglu_oai", il); } break; case LLM_FFN_RELU: - if (gate_exps) { + if (has_gate) { cur = ggml_reglu_split(ctx0, cur, up); cb(cur, "ffn_moe_reglu", il); } else { @@ -1417,7 +1445,7 @@ ggml_tensor * llm_graph_context::build_moe_ffn( cb(cur, "ffn_moe_relu", il); } break; case LLM_FFN_RELU_SQR: - if (gate_exps) { + if (has_gate) { // TODO: add support for gated squared relu GGML_ABORT("fatal error: gated squared relu not implemented"); } else { diff --git a/src/llama-graph.h b/src/llama-graph.h index 22d11a8385..e8f006977d 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -814,7 +814,8 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr) const; ggml_tensor * build_moe_ffn( ggml_tensor * cur, @@ -835,7 +836,9 @@ struct llm_graph_context { float w_scale, llama_expert_gating_func_type gating_op, int il, - ggml_tensor * probs_in = nullptr) const; + ggml_tensor * probs_in = nullptr, + ggml_tensor * gate_up_exps = nullptr, + ggml_tensor * gate_up_exps_b = nullptr) const; // // inputs diff --git a/src/llama-kv-cache.cpp b/src/llama-kv-cache.cpp index cb702b2a59..6b668ee9ab 100644 --- a/src/llama-kv-cache.cpp +++ b/src/llama-kv-cache.cpp @@ -978,6 +978,9 @@ bool llama_kv_cache::get_can_shift() const { if (model.arch == LLM_ARCH_STEP35) { return false; } + if (hparams.n_pos_per_embd() > 1) { + return false; + } return true; } diff --git a/src/llama-memory-recurrent.cpp b/src/llama-memory-recurrent.cpp index f0038036dc..6e8413f493 100644 --- a/src/llama-memory-recurrent.cpp +++ b/src/llama-memory-recurrent.cpp @@ -163,7 +163,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos const auto & cell = cells[tail_id]; // partial intersection is invalid if it includes the final pos if (0 < p0 && p0 <= cell.pos && p1 > cell.pos) { - //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false, p0 = %d, cell.pos = %d, p1 = %d\n", p0, cell.pos, p1); return false; } // invalidate tails which will be cleared diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 764839b9bc..dabf3b3086 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -123,6 +123,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_16B_A1B: return "16B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; + case LLM_TYPE_24B_A2B: return "24B.A2B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_35B_A3B: return "35B.A3B"; @@ -978,6 +979,16 @@ void llama_model::load_hparams(llama_model_loader & ml) { type = LLM_TYPE_250M; } } break; + case LLM_ARCH_EUROBERT: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_ATTENTION_CAUSAL, hparams.causal_attn); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + if (hparams.n_layer == 12) { + type = LLM_TYPE_SMALL; // 0.2B + } + } break; case LLM_ARCH_BLOOM: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1703,8 +1714,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; case LLM_ARCH_DEEPSEEK2: { - // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B - const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26); + // lite variants include DeepSeek-V2-Lite, GigaChat3-10B-A1.8B, Kanana-2-30B-A3B + const bool is_lite = (hparams.n_layer == 27 || hparams.n_layer == 26 || (hparams.n_layer == 48 && n_vocab == 128256)); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); @@ -2381,7 +2392,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } - type = LLM_TYPE_8B_A1B; + switch (hparams.n_layer) { + case 24: type = LLM_TYPE_8B_A1B; break; + case 40: type = LLM_TYPE_24B_A2B; break; + default: type = LLM_TYPE_UNKNOWN; + } } break; case LLM_ARCH_SMALLTHINKER: { @@ -2965,6 +2980,15 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // TODO: move to a separate function const auto tn = LLM_TN(arch); + + // helper: try merged gate_up_exps first, fall back to separate gate and up + auto create_tensor_gate_up_exps = [&](llama_layer & layer, int bid, int64_t n_embd_, int64_t n_ff_, int64_t n_expert_, int flags) { + layer.ffn_gate_up_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_UP_EXPS, "weight", bid), {n_embd_, n_ff_ * 2, n_expert_}, TENSOR_NOT_REQUIRED); + if (layer.ffn_gate_up_exps == nullptr) { + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", bid), {n_embd_, n_ff_, n_expert_}, flags); + } + }; switch (arch) { case LLM_ARCH_LLAMA: case LLM_ARCH_REFACT: @@ -3565,6 +3589,29 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); } } break; + case LLM_ARCH_EUROBERT: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd, n_embd}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); + } + } break; case LLM_ARCH_JINA_BERT_V2: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // word_embeddings @@ -5183,9 +5230,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } // MoE branch - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared expert branch layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, 0); @@ -7387,9 +7433,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0); @@ -7453,9 +7498,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0); - layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0); - layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0); + create_tensor_gate_up_exps(layer, i, n_embd, n_ff_exp, n_expert, 0); // Shared experts const int64_t n_ff_shexp = hparams.n_ff_shexp ? hparams.n_ff_shexp : n_ff; @@ -8176,6 +8220,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: case LLM_ARCH_NEO_BERT: + case LLM_ARCH_EUROBERT: case LLM_ARCH_WAVTOKENIZER_DEC: case LLM_ARCH_MODERN_BERT: case LLM_ARCH_GEMMA_EMBEDDING: @@ -8373,6 +8418,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_EUROBERT: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_BLOOM: { llm = std::make_unique(*this, params); @@ -8999,6 +9048,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MODERN_BERT: case LLM_ARCH_NOMIC_BERT: case LLM_ARCH_NOMIC_BERT_MOE: + case LLM_ARCH_EUROBERT: case LLM_ARCH_STABLELM: case LLM_ARCH_BITNET: case LLM_ARCH_QWEN: diff --git a/src/llama-model.h b/src/llama-model.h index 422ed45699..d7c3e7d1c1 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -116,6 +116,7 @@ enum llm_type { LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_16B_A1B, LLM_TYPE_21B_A3B, // Ernie MoE small + LLM_TYPE_24B_A2B, // lfm2moe LLM_TYPE_30B_A3B, LLM_TYPE_31B_A3_5B, LLM_TYPE_35B_A3B, // Qwen3.5 @@ -279,14 +280,16 @@ struct llama_layer { struct ggml_tensor * ffn_up_enc = nullptr; // ff MoE - struct ggml_tensor * ffn_gate_inp = nullptr; - struct ggml_tensor * ffn_gate_exps = nullptr; - struct ggml_tensor * ffn_down_exps = nullptr; - struct ggml_tensor * ffn_up_exps = nullptr; - struct ggml_tensor * ffn_gate_inp_b = nullptr; - struct ggml_tensor * ffn_gate_exps_b = nullptr; - struct ggml_tensor * ffn_down_exps_b = nullptr; - struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_inp = nullptr; + struct ggml_tensor * ffn_gate_exps = nullptr; + struct ggml_tensor * ffn_down_exps = nullptr; + struct ggml_tensor * ffn_up_exps = nullptr; + struct ggml_tensor * ffn_gate_up_exps = nullptr; + struct ggml_tensor * ffn_gate_inp_b = nullptr; + struct ggml_tensor * ffn_gate_exps_b = nullptr; + struct ggml_tensor * ffn_down_exps_b = nullptr; + struct ggml_tensor * ffn_up_exps_b = nullptr; + struct ggml_tensor * ffn_gate_up_exps_b = nullptr; // ff shared expert (shexp) struct ggml_tensor * ffn_gate_inp_shexp = nullptr; diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index 69b25a1bf9..194eed238e 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -1890,7 +1890,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "falcon-h1" || tokenizer_pre == "pixtral" || tokenizer_pre == "midm-2.0" || - tokenizer_pre == "lfm2") { + tokenizer_pre == "lfm2" || + tokenizer_pre == "jina-v5-nano") { pre_type = LLAMA_VOCAB_PRE_TYPE_LLAMA3; ignore_merges = true; add_bos = true; @@ -2027,7 +2028,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_QWEN2; } else if ( tokenizer_pre == "gpt-4o" || - tokenizer_pre == "llama4") { + tokenizer_pre == "llama4" || + tokenizer_pre == "kanana2") { pre_type = LLAMA_VOCAB_PRE_TYPE_GPT4O; clean_spaces = false; } else if ( diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index b2c1f16060..b608396e50 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -218,7 +218,9 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr LLM_FFN_SILU, hparams.expert_weights_norm, hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, - il); + il, + nullptr, + model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // FFN shared expert diff --git a/src/models/eurobert.cpp b/src/models/eurobert.cpp new file mode 100644 index 0000000000..86e3176edc --- /dev/null +++ b/src/models/eurobert.cpp @@ -0,0 +1,97 @@ +#include "models.h" + +llm_build_eurobert::llm_build_eurobert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + + ggml_tensor * cur; + ggml_tensor * inpL; + ggml_tensor * inp_pos = build_inp_pos(); + + inpL = build_inp_embd(model.tok_embd); + cb(inpL, "inp_embd", -1); + + auto * inp_attn = build_attn_inp_no_cache(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * cur = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + + { + ggml_tensor * Qcur; + ggml_tensor * Kcur; + ggml_tensor * Vcur; + + Qcur = build_lora_mm(model.layers[il].wq, cur); + Kcur = build_lora_mm(model.layers[il].wk, cur); + Vcur = build_lora_mm(model.layers[il].wv, cur); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, nullptr, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + cb(cur, "kqv_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpL = ggml_get_rows(ctx0, inpL, inp_out_ids); + } + + cur = ggml_add(ctx0, cur, inpL); + + ggml_tensor * ffn_inp = cur; + cb(ffn_inp, "ffn_inp", il); + + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_embd", -1); + res->t_embd = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 4d6bb83c14..83d11241f8 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -116,6 +116,8 @@ llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const ll cur = build_norm(inpL, layer.attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Check layer type by checking which tensors exist // KDA layers have ssm_a_log tensor, MLA layers have wkv_a_mqa tensor bool is_kda = (layer.ssm_a != nullptr); diff --git a/src/models/models.h b/src/models/models.h index 10f8b58921..0712d03d8d 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -424,6 +424,10 @@ struct llm_build_neo_bert : public llm_graph_context { llm_build_neo_bert(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_eurobert : public llm_graph_context { + llm_build_eurobert(const llama_model & model, const llm_graph_params & params); +}; + template struct llm_build_olmo2 : public llm_graph_context { llm_build_olmo2(const llama_model & model, const llm_graph_params & params); diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 56eefd7de2..bacf7a4c2e 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -29,6 +29,8 @@ llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_pa cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) @@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index c7295e3364..22d708f206 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -29,6 +29,8 @@ llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) @@ -269,7 +271,6 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -379,7 +380,8 @@ ggml_tensor * llm_build_qwen35moe ::build_layer_ffn(ggml_tensor * cur, const int model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 974120ea6f..f2621200f2 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -21,6 +21,8 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr cur = build_norm(inpL, model.layers[il].attn_norm, nullptr, LLM_NORM_RMS, il); cb(cur, "attn_norm", il); + ggml_build_forward_expand(gf, cur); + // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) @@ -354,7 +356,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(state_update_target, "state_update_target", il); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); - cb(conv_states_all, "conv_states_updated", il); ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); @@ -478,7 +479,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int model.layers[il].ffn_gate_exps, model.layers[il].ffn_down_exps, nullptr, n_expert, n_expert_used, LLM_FFN_SILU, - true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il); + true, false, 0.0, LLAMA_EXPERT_GATING_FUNC_TYPE_SOFTMAX, il, + nullptr, model.layers[il].ffn_gate_up_exps); cb(moe_out, "ffn_moe_out", il); // Add shared experts if present - following Qwen3Next reference implementation diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 350bffc315..7e0b17a7c1 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -152,7 +152,7 @@ if (NOT WIN32 OR NOT BUILD_SHARED_LIBS) llama_build_and_test(test-grammar-parser.cpp) llama_build_and_test(test-grammar-integration.cpp) llama_build_and_test(test-llama-grammar.cpp) - llama_build_and_test(test-chat.cpp) + llama_build_and_test(test-chat.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) # TODO: disabled on loongarch64 because the ggml-ci node lacks Python 3.8 if (NOT ${CMAKE_SYSTEM_PROCESSOR} MATCHES "loongarch64") llama_build_and_test(test-json-schema-to-grammar.cpp WORKING_DIRECTORY ${PROJECT_SOURCE_DIR}) @@ -257,6 +257,21 @@ set(LLAMA_TEST_NAME test-mtmd-c-api) llama_build_and_test(test-mtmd-c-api.c) target_link_libraries(${LLAMA_TEST_NAME} PRIVATE mtmd) +# GGUF model data fetcher library for tests that need real model metadata +# Only compile when cpp-httplib has SSL support (CPPHTTPLIB_OPENSSL_SUPPORT) +if (TARGET cpp-httplib) + get_target_property(_cpp_httplib_defs cpp-httplib INTERFACE_COMPILE_DEFINITIONS) + if (_cpp_httplib_defs MATCHES "CPPHTTPLIB_OPENSSL_SUPPORT") + add_library(gguf-model-data STATIC gguf-model-data.cpp) + target_link_libraries(gguf-model-data PRIVATE common cpp-httplib) + target_include_directories(gguf-model-data PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + + add_executable(test-gguf-model-data test-gguf-model-data.cpp) + target_link_libraries(test-gguf-model-data PRIVATE gguf-model-data common) + llama_test(test-gguf-model-data LABEL "model") + endif() +endif() + # dummy executable - not installed get_filename_component(TEST_TARGET test-c.c NAME_WE) add_executable(${TEST_TARGET} test-c.c) diff --git a/tests/gguf-model-data.cpp b/tests/gguf-model-data.cpp new file mode 100644 index 0000000000..3bc82c88da --- /dev/null +++ b/tests/gguf-model-data.cpp @@ -0,0 +1,613 @@ +// GGUF binary parser adapted from the huggingface/gguf package. +// Reference: https://github.com/huggingface/huggingface.js + +#include "gguf-model-data.h" + +#include "common.h" +#include "gguf.h" + +#include +#include +#include +#include +#include + +#include "http.h" +#define JSON_ASSERT GGML_ASSERT +#include + +// Equivalent of RangeView +struct gguf_buf_reader { + const char * data; + size_t size; + size_t pos; + + gguf_buf_reader(const std::vector & buf) : data(buf.data()), size(buf.size()), pos(0) {} + + bool has_n_bytes(size_t n) const { + return pos + n <= size; + } + + template + bool read_val(T & out) { + if (!has_n_bytes(sizeof(T))) { + return false; + } + memcpy(&out, data + pos, sizeof(T)); + pos += sizeof(T); + return true; + } + + bool read_str(std::string & out) { + uint64_t len; + if (!read_val(len)) { + return false; + } + if (!has_n_bytes((size_t)len)) { + return false; + } + out.assign(data + pos, (size_t)len); + pos += (size_t)len; + return true; + } + + bool skip(size_t n) { + if (!has_n_bytes(n)) { + return false; + } + pos += n; + return true; + } +}; + +static size_t gguf_val_type_size(int32_t vtype) { + switch (vtype) { + case GGUF_TYPE_UINT8: return 1; + case GGUF_TYPE_INT8: return 1; + case GGUF_TYPE_UINT16: return 2; + case GGUF_TYPE_INT16: return 2; + case GGUF_TYPE_UINT32: return 4; + case GGUF_TYPE_INT32: return 4; + case GGUF_TYPE_FLOAT32: return 4; + case GGUF_TYPE_BOOL: return 1; + case GGUF_TYPE_UINT64: return 8; + case GGUF_TYPE_INT64: return 8; + case GGUF_TYPE_FLOAT64: return 8; + default: return 0; // string/array handled separately + } +} + +// Equivalent of readMetadataValue(), skips unused values rather than storing +static bool gguf_skip_value(gguf_buf_reader & r, int32_t vtype) { + if (vtype == GGUF_TYPE_STRING) { + std::string tmp; + return r.read_str(tmp); + } + if (vtype == GGUF_TYPE_ARRAY) { + int32_t elem_type; + uint64_t count; + if (!r.read_val(elem_type)) { + return false; + } + if (!r.read_val(count)) { + return false; + } + if (elem_type == GGUF_TYPE_STRING) { + for (uint64_t i = 0; i < count; i++) { + std::string tmp; + if (!r.read_str(tmp)) { + return false; + } + } + return true; + } + if (elem_type == GGUF_TYPE_ARRAY) { + // nested arrays - recurse + for (uint64_t i = 0; i < count; i++) { + if (!gguf_skip_value(r, GGUF_TYPE_ARRAY)) { + return false; + } + } + return true; + } + size_t elem_sz = gguf_val_type_size(elem_type); + if (elem_sz == 0) { + return false; + } + return r.skip((size_t)count * elem_sz); + } + size_t sz = gguf_val_type_size(vtype); + if (sz == 0) { + return false; + } + return r.skip(sz); +} + +static bool gguf_read_uint32_val(gguf_buf_reader & r, int32_t vtype, uint32_t & out) { + if (vtype == GGUF_TYPE_UINT8) { + uint8_t v; + if (!r.read_val(v)) { + return false; + } + out = v; + return true; + } + if (vtype == GGUF_TYPE_INT8) { + int8_t v; + if (!r.read_val(v)) { + return false; + } + out = (uint32_t)v; + return true; + } + if (vtype == GGUF_TYPE_UINT16) { + uint16_t v; + if (!r.read_val(v)) { + return false; + } + out = v; + return true; + } + if (vtype == GGUF_TYPE_INT16) { + int16_t v; + if (!r.read_val(v)) { + return false; + } + out = (uint32_t)v; + return true; + } + if (vtype == GGUF_TYPE_UINT32) { + uint32_t v; + if (!r.read_val(v)) { + return false; + } + out = v; + return true; + } + if (vtype == GGUF_TYPE_INT32) { + int32_t v; + if (!r.read_val(v)) { + return false; + } + out = (uint32_t)v; + return true; + } + if (vtype == GGUF_TYPE_UINT64) { + uint64_t v; + if (!r.read_val(v)) { + return false; + } + out = (uint32_t)v; + return true; + } + if (vtype == GGUF_TYPE_INT64) { + int64_t v; + if (!r.read_val(v)) { + return false; + } + out = (uint32_t)v; + return true; + } + return false; +} + +// Follows the same header -> KV -> tensor parsing sequence as gguf() huggingface/gguf +static std::optional gguf_parse_meta(const std::vector & buf) { + gguf_buf_reader r(buf); + + // Header: magic(4) + version(4) + tensor_count(8) + kv_count(8) = 24 bytes minimum + uint32_t magic_raw; + if (!r.read_val(magic_raw)) { + return std::nullopt; + } + if (memcmp(&magic_raw, "GGUF", 4) != 0) { + fprintf(stderr, "gguf_parse_meta: invalid magic\n"); + return std::nullopt; + } + + uint32_t version; + if (!r.read_val(version)) { + return std::nullopt; + } + if (version < 2 || version > 3) { + fprintf(stderr, "gguf_parse_meta: unsupported version %u\n", version); + return std::nullopt; + } + + int64_t tensor_count_raw; + int64_t kv_count_raw; + if (!r.read_val(tensor_count_raw)) { + return std::nullopt; + } + if (!r.read_val(kv_count_raw)) { + return std::nullopt; + } + + uint64_t tensor_count = (uint64_t)tensor_count_raw; + uint64_t kv_count = (uint64_t)kv_count_raw; + + gguf_remote_model model; + + std::string arch_prefix; + + // Parse KV pairs + for (uint64_t i = 0; i < kv_count; i++) { + std::string key; + if (!r.read_str(key)) { + return std::nullopt; + } + + int32_t vtype; + if (!r.read_val(vtype)) { + return std::nullopt; + } + + if (key == "general.architecture" && vtype == GGUF_TYPE_STRING) { + if (!r.read_str(model.architecture)) { + return std::nullopt; + } + arch_prefix = model.architecture + "."; + continue; + } + + // Extract split.count for proper handling of split files + if (key == "split.count") { + uint32_t v; + if (!gguf_read_uint32_val(r, vtype, v)) { + return std::nullopt; + } + model.n_split = (uint16_t)v; + continue; + } + + // Extract split.tensors.count so we can verify we have all tensors + if (key == "split.tensors.count") { + uint32_t v; + if (!gguf_read_uint32_val(r, vtype, v)) { + return std::nullopt; + } + model.n_split_tensors = v; + continue; + } + + if (!arch_prefix.empty()) { + uint32_t * target = nullptr; + + if (key == arch_prefix + "embedding_length") { target = &model.n_embd; } + else if (key == arch_prefix + "feed_forward_length") { target = &model.n_ff; } + else if (key == arch_prefix + "block_count") { target = &model.n_layer; } + else if (key == arch_prefix + "attention.head_count") { target = &model.n_head; } + else if (key == arch_prefix + "attention.head_count_kv") { target = &model.n_head_kv; } + else if (key == arch_prefix + "expert_count") { target = &model.n_expert; } + else if (key == arch_prefix + "attention.key_length") { target = &model.n_embd_head_k; } + else if (key == arch_prefix + "attention.value_length") { target = &model.n_embd_head_v; } + + if (target) { + if (!gguf_read_uint32_val(r, vtype, *target)) { + return std::nullopt; + } + continue; + } + } + + if (!gguf_skip_value(r, vtype)) { + return std::nullopt; + } + } + + // Parse tensor info entries + model.tensors.reserve((size_t)tensor_count); + for (uint64_t i = 0; i < tensor_count; i++) { + gguf_remote_tensor t; + + if (!r.read_str(t.name)) { + return std::nullopt; + } + if (!r.read_val(t.n_dims)) { + return std::nullopt; + } + + if (t.n_dims > 4) { + fprintf(stderr, "gguf_parse_meta: tensor '%s' has %u dims (max 4)\n", t.name.c_str(), t.n_dims); + return std::nullopt; + } + + for (uint32_t d = 0; d < t.n_dims; d++) { + if (!r.read_val(t.ne[d])) { + return std::nullopt; + } + } + + int32_t type_raw; + if (!r.read_val(type_raw)) { + return std::nullopt; + } + t.type = (ggml_type)type_raw; + + uint64_t offset; + if (!r.read_val(offset)) { + return std::nullopt; + } + + // Infer n_vocab from token_embd.weight + if (t.name == "token_embd.weight") { + model.n_vocab = (uint32_t)t.ne[1]; + } + + model.tensors.push_back(std::move(t)); + } + + return model; +} + +// cache handling for local download +static std::string get_default_cache_dir() { + return fs_get_cache_directory() + "gguf-headers/"; +} + +static std::string sanitize_for_path(const std::string & s) { + std::string out = s; + for (char & c : out) { + if (c == '/' || c == '\\' || c == ':') { + c = '_'; + } + } + return out; +} + +static bool read_file(const std::string & path, std::vector & out) { + std::ifstream f(path, std::ios::binary | std::ios::ate); + if (!f.good()) { + return false; + } + auto sz = f.tellg(); + if (sz <= 0) { + return false; + } + out.resize((size_t)sz); + f.seekg(0); + f.read(out.data(), sz); + return f.good(); +} + +static bool write_file(const std::string & path, const std::vector & data) { + std::ofstream f(path, std::ios::binary | std::ios::trunc); + if (!f.good()) { + return false; + } + f.write(data.data(), (std::streamsize)data.size()); + return f.good(); +} + +// HuggingFace file auto-detection and HTTP download +static std::pair> gguf_http_get( + const std::string & url, + const httplib::Headers & headers = {}, + int timeout_sec = 60) { + try { + auto [cli, parts] = common_http_client(url); + + if (timeout_sec > 0) { + cli.set_read_timeout(timeout_sec, 0); + cli.set_write_timeout(timeout_sec, 0); + } + cli.set_connection_timeout(30, 0); + + std::vector body; + auto res = cli.Get(parts.path, headers, + [&](const char * data, size_t len) { + body.insert(body.end(), data, data + len); + return true; + }, nullptr); + + if (!res) { + fprintf(stderr, "gguf_fetch: HTTP request failed for %s (error %d)\n", + url.c_str(), (int)res.error()); + return {-1, {}}; + } + return {res->status, std::move(body)}; + } catch (const std::exception & e) { + fprintf(stderr, "gguf_fetch: HTTP error: %s\n", e.what()); + return {-1, {}}; + } +} + +// Find the filename for given repo/quant. +// For split models, returns the first shard (the one containing "00001-of-") +// split_prefix is set to the portion before "-00001-of-XXXXX.gguf" when a split file is found +static std::string detect_gguf_filename(const std::string & repo, const std::string & quant, + std::string & split_prefix) { + split_prefix.clear(); + std::string api_url = "https://huggingface.co/api/models/" + repo; + + auto [code, body] = gguf_http_get(api_url, {}, 30); + if (code != 200 || body.empty()) { + fprintf(stderr, "gguf_fetch: failed to query HF API for %s (HTTP %ld)\n", repo.c_str(), code); + return ""; + } + + nlohmann::json j; + try { + j = nlohmann::json::parse(body.begin(), body.end()); + } catch (...) { + fprintf(stderr, "gguf_fetch: failed to parse HF API response\n"); + return ""; + } + + if (!j.contains("siblings") || !j["siblings"].is_array()) { + fprintf(stderr, "gguf_fetch: unexpected HF API response format\n"); + return ""; + } + + std::vector matches; + std::string quant_upper = quant; + for (char & c : quant_upper) { c = (char)toupper(c); } + + for (const auto & sibling : j["siblings"]) { + if (!sibling.contains("rfilename")) { continue; } + std::string fname = sibling["rfilename"].get(); + if (fname.size() < 5 || fname.substr(fname.size() - 5) != ".gguf") { + continue; + } + + std::string fname_upper = fname; + for (char & c : fname_upper) { c = (char)toupper(c); } + if (fname_upper.find(quant_upper) != std::string::npos) { + matches.push_back(fname); + } + } + + if (matches.empty()) { + fprintf(stderr, "gguf_fetch: no .gguf files matching '%s' in %s\n", quant.c_str(), repo.c_str()); + return ""; + } + + std::sort(matches.begin(), matches.end()); + + // Prefer non-split, non-supplementary file + for (const auto & m : matches) { + if (m.find("-of-") == std::string::npos && m.find("mmproj") == std::string::npos) { + return m; + } + } + + // Return the first shard (00001-of-) and extract the prefix + for (const auto & m : matches) { + auto pos = m.find("-00001-of-"); + if (pos != std::string::npos) { + split_prefix = m.substr(0, pos); + return m; + } + } + + return matches[0]; +} + +static std::optional fetch_and_parse( + const std::string & repo, + const std::string & filename, + const std::string & cache_path) { + std::string url = "https://huggingface.co/" + repo + "/resolve/main/" + filename; + + // Progressive download inspired by RangeView.fetchChunk() + // Start at 2MB, double each time, cap at 64MB + size_t chunk_size = 2 * 1024 * 1024; + const size_t max_chunk = 64 * 1024 * 1024; + + while (chunk_size <= max_chunk) { + fprintf(stderr, "gguf_fetch: downloading %zu bytes from %s\n", chunk_size, filename.c_str()); + + char range_buf[64]; + snprintf(range_buf, sizeof(range_buf), "bytes=0-%zu", chunk_size - 1); + httplib::Headers headers = {{"Range", range_buf}}; + + auto [code, body] = gguf_http_get(url, headers, 120); + if (code != 200 && code != 206) { + fprintf(stderr, "gguf_fetch: HTTP %ld fetching %s\n", code, url.c_str()); + return std::nullopt; + } + + if (body.empty()) { + fprintf(stderr, "gguf_fetch: empty response\n"); + return std::nullopt; + } + + auto result = gguf_parse_meta(body); + if (result.has_value()) { + write_file(cache_path, body); + return result; + } + + if (code == 200) { + fprintf(stderr, "gguf_fetch: server returned full response but metadata parse failed\n"); + return std::nullopt; + } + + // Parse failed, try larger chunk + chunk_size *= 2; + } + + fprintf(stderr, "gguf_fetch: metadata exceeds 64MB, giving up\n"); + return std::nullopt; +} + +// Try cache first, then fetch and parse a single GGUF shard. +static std::optional fetch_or_cached( + const std::string & repo, + const std::string & filename, + const std::string & cdir, + const std::string & repo_part) { + std::string fname_part = sanitize_for_path(filename); + std::string cache_path = cdir + "/" + repo_part + "--" + fname_part + ".partial"; + + { + std::vector cached; + if (std::filesystem::exists(cache_path) && read_file(cache_path, cached)) { + auto result = gguf_parse_meta(cached); + if (result.has_value()) { + fprintf(stderr, "gguf_fetch: loaded from cache: %s\n", cache_path.c_str()); + return result; + } + } + } + + fs_create_directory_with_parents(cdir); + return fetch_and_parse(repo, filename, cache_path); +} + +std::optional gguf_fetch_model_meta( + const std::string & repo, + const std::string & quant, + const std::string & cache_dir) { + std::string cdir = cache_dir.empty() ? get_default_cache_dir() : cache_dir; + std::string repo_part = sanitize_for_path(repo); + + std::string split_prefix; + std::string filename = detect_gguf_filename(repo, quant, split_prefix); + if (filename.empty()) { + return std::nullopt; + } + + auto model_opt = fetch_or_cached(repo, filename, cdir, repo_part); + if (!model_opt.has_value()) { + fprintf(stderr, "gguf_fetch: failed to fetch %s\n", filename.c_str()); + return std::nullopt; + } + + auto & model = model_opt.value(); + + // If the model is split across multiple files we need to fetch the remaining shards metadata + if (model.n_split > 1) { + if (split_prefix.empty()) { + fprintf(stderr, "gguf_fetch: model reports %u splits but filename has no split pattern\n", model.n_split); + return std::nullopt; + } + + fprintf(stderr, "gguf_fetch: split model with %u shards, fetching remaining %u...\n", + model.n_split, model.n_split - 1); + + for (int i = 2; i <= model.n_split; i++) { + char num_buf[6], total_buf[6]; + snprintf(num_buf, sizeof(num_buf), "%05d", i); + snprintf(total_buf, sizeof(total_buf), "%05d", (int)model.n_split); + std::string shard_name = split_prefix + "-" + num_buf + "-of-" + total_buf + ".gguf"; + + auto shard = fetch_or_cached(repo, shard_name, cdir, repo_part); + if (!shard.has_value()) { + fprintf(stderr, "gguf_fetch: failed to fetch shard %d: %s\n", i, shard_name.c_str()); + return std::nullopt; + } + + model.tensors.insert(model.tensors.end(), + std::make_move_iterator(shard->tensors.begin()), + std::make_move_iterator(shard->tensors.end())); + } + + if (model.n_split_tensors > 0 && model.tensors.size() != model.n_split_tensors) { + fprintf(stderr, "gguf_fetch: WARNING: expected %u tensors from split.tensors.count, got %zu\n", + model.n_split_tensors, model.tensors.size()); + } + } + + return model_opt; +} diff --git a/tests/gguf-model-data.h b/tests/gguf-model-data.h new file mode 100644 index 0000000000..ed433791ad --- /dev/null +++ b/tests/gguf-model-data.h @@ -0,0 +1,42 @@ +#pragma once + +#include "ggml.h" + +#include +#include +#include +#include + +struct gguf_remote_tensor { + std::string name; + ggml_type type = GGML_TYPE_F32; + int64_t ne[4] = {1, 1, 1, 1}; // dimensions, unused dims = 1 + uint32_t n_dims = 0; +}; + +struct gguf_remote_model { + // Selected KV metadata + std::string architecture; // general.architecture + uint32_t n_embd = 0; // .embedding_length + uint32_t n_ff = 0; // .feed_forward_length + uint32_t n_vocab = 0; // inferred from token_embd.weight ne[1] + uint32_t n_layer = 0; // .block_count + uint32_t n_head = 0; // .attention.head_count + uint32_t n_head_kv = 0; // .attention.head_count_kv + uint32_t n_expert = 0; // .expert_count (0 if absent) + uint32_t n_embd_head_k = 0; // .attention.key_length + uint32_t n_embd_head_v = 0; // .attention.value_length + uint16_t n_split = 0; // split.count (0 = not split) + uint32_t n_split_tensors = 0; // split.tensors.count (0 if not split) + + std::vector tensors; +}; + +// Fetch model metadata from HuggingFace with local caching. +// repo: e.g., "ggml-org/Qwen3-32B-GGUF" +// quant: e.g., "Q8_0" -- auto-detects filename (including first shard of split models) +// Returns nullopt if download fails or network is unavailable. +std::optional gguf_fetch_model_meta( + const std::string & repo, + const std::string & quant = "Q8_0", + const std::string & cache_dir = ""); // empty = default diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp index c10bde91b6..d4cd62c71e 100644 --- a/tests/test-backend-sampler.cpp +++ b/tests/test-backend-sampler.cpp @@ -361,7 +361,7 @@ static void test_backend_temp_sampling(const test_params & params) { GGML_ASSERT(false && "Failed to decode token"); } - // Verfify sequence 0 + // Verify sequence 0 { int32_t batch_idx = test_ctx.idx_for_seq(0); int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); @@ -379,7 +379,7 @@ static void test_backend_temp_sampling(const test_params & params) { } - // Verfify sequence 1 + // Verify sequence 1 { int32_t batch_idx = test_ctx.idx_for_seq(1); @@ -395,7 +395,7 @@ static void test_backend_temp_sampling(const test_params & params) { } } - // lambda to testing non-positive temperature values. + // lambda for testing non-positive temperature values. auto test_argmax_temp = [&](float temp) { printf("\nTesting temperature = %.1f\n", temp); @@ -454,7 +454,7 @@ static void test_backend_temp_ext_sampling(const test_params & params) { } } - // lambda to testing non-positive temp/delta/exponent values. + // lambda for testing non-positive temp/delta/exponent values. auto test_argmax_temp = [&](float temp, float delta, float exponent) { printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent); @@ -530,7 +530,7 @@ static void test_backend_min_p_sampling(const test_params & params) { printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); - // Decode and sampler 10 more tokens + // Decode and sample 10 more tokens for (int i = 0; i < 10; i++) { int32_t loop_idx = test_ctx.idx_for_seq(seq_id); llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); @@ -582,7 +582,7 @@ static void test_backend_top_p_sampling(const test_params & params) { printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); - // Decode and sampler 10 more tokens + // Decode and sample 10 more tokens for (int i = 0; i < 10; i++) { int32_t loop_idx = test_ctx.idx_for_seq(seq_id); llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); @@ -619,7 +619,7 @@ static void test_backend_multi_sequence_sampling(const test_params & params) { GGML_ASSERT(false && "Failed to decode token"); } - // Verfiy sequence 0 + // Verify sequence 0 { int32_t batch_idx = test_ctx.idx_for_seq(0); llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); @@ -763,7 +763,7 @@ static void test_backend_logit_bias_sampling(const test_params & params) { printf("backend logit bias sampling test PASSED\n"); } -// This test verifies that it is possible to have two different backend sampler, +// This test verifies that it is possible to have two different backend samplers, // one that uses the backend dist sampler, and another that uses CPU dist sampler. static void test_backend_mixed_sampling(const test_params & params) { struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); @@ -791,7 +791,7 @@ static void test_backend_mixed_sampling(const test_params & params) { GGML_ASSERT(false && "Failed to decode token"); } - // Verfiy sequence 0 that used the dist backend sampler. + // Verify sequence 0 that used the dist backend sampler. { int32_t batch_idx = test_ctx.idx_for_seq(0); llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); @@ -802,7 +802,7 @@ static void test_backend_mixed_sampling(const test_params & params) { //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0); } - // Verfiy sequence 1 that used the top-k backend sampler. + // Verify sequence 1 that used the top-k backend sampler. { int32_t batch_idx = test_ctx.idx_for_seq(1); float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); @@ -934,7 +934,7 @@ static void test_backend_cpu_mixed_batch(const test_params & params) { // samplers. llama_set_sampler(test_ctx.ctx.get(), 0, nullptr); - // Create a CPU sampler and verify we can sampler from it. + // Create a CPU sampler and verify we can sample from it. struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy()); diff --git a/tests/test-chat.cpp b/tests/test-chat.cpp index 1bef5b9f44..f3d19118b5 100644 --- a/tests/test-chat.cpp +++ b/tests/test-chat.cpp @@ -229,6 +229,20 @@ common_chat_tool python_tool { "required": ["code"] })", }; +common_chat_tool todo_list_tool { + /* .name = */ "todo_list", + /* .description = */ "Create or update the todo list", + /* .parameters = */ R"({ + "type": "object", + "properties": { + "todos": { + "type": "array", + "description": "List of TODO list items" + } + }, + "required": ["todos"] + })", +}; common_chat_tool code_interpreter_tool { /* .name = */ "code_interpreter", /* .description = */ "an ipython interpreter", @@ -3018,542 +3032,6 @@ Hey there!<|im_end|> ); } - // Test Qwen3-Coder XML format - { - // Basic XML tool call parsing - assert_msg_equals( - message_assist_call, - test_chat_parse( - "\n" - " \n" - " \n" - " 1\n" - " \n" - " \n" - "", - /* is_partial= */ false, - {COMMON_CHAT_FORMAT_QWEN3_CODER_XML})); - - // Multiple parameters with different types - common_chat_msg expected_multi_param; - expected_multi_param.role = "assistant"; - expected_multi_param.tool_calls = { - { "complex_function", "{\"name\":\"John Doe\",\"age\":30,\"active\":true,\"score\":95.5}", "" } - }; - - test_parser_with_streaming(expected_multi_param, - "\n" - " \n" - " \n" - " John Doe\n" - " \n" - " \n" - " 30\n" - " \n" - " \n" - " true\n" - " \n" - " \n" - " 95.5\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Special characters and Unicode - common_chat_msg expected_special_chars; - expected_special_chars.role = "assistant"; - expected_special_chars.tool_calls = { - { "unicode_function", "{\"message\":\"Hello 世界! 🌍 Special chars: @#$%^&*()\"}", "" } - }; - - test_parser_with_streaming(expected_special_chars, - "\n" - " \n" - " \n" - " Hello 世界! 🌍 Special chars: @#$%^&*()\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Multiline content with newlines and indentation - common_chat_msg expected_multiline; - expected_multiline.role = "assistant"; - expected_multiline.tool_calls = { - { "code_function", "{\"code\":\"def hello():\\n print(\\\"Hello, World!\\\")\\n return True\"}", "" } - }; - - test_parser_with_streaming(expected_multiline, - "\n" - " \n" - " \n" - "def hello():\n" - " print(\"Hello, World!\")\n" - " return True\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // JSON object as parameter value - common_chat_msg expected_json_param; - expected_json_param.role = "assistant"; - expected_json_param.tool_calls = { - { "json_function", "{\"config\":{\"host\":\"localhost\",\"port\":8080,\"ssl\":false}}", "" } - }; - - test_parser_with_streaming( - expected_json_param, - "\n" - " \n" - " \n" - " {\"host\": \"localhost\", \"port\": 8080, \"ssl\": false}\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Array as parameter value - common_chat_msg expected_array_param; - expected_array_param.role = "assistant"; - expected_array_param.tool_calls = { - { "array_function", "{\"items\":[\"apple\",\"banana\",\"cherry\"]}", "" } - }; - - test_parser_with_streaming( - expected_array_param, - "\n" - " \n" - " \n" - " [\"apple\", \"banana\", \"cherry\"]\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Empty parameter - common_chat_msg expected_empty_param; - expected_empty_param.role = "assistant"; - expected_empty_param.tool_calls = { - { "empty_function", "{\"empty_param\":\"\"}", "" } - }; - - test_parser_with_streaming( - expected_empty_param, - "\n" - " \n" - " \n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Boolean values (true/false) - common_chat_msg expected_boolean; - expected_boolean.role = "assistant"; - expected_boolean.tool_calls = { - { "boolean_function", "{\"enabled\":true,\"debug\":false}", "" } - }; - - test_parser_with_streaming( - expected_boolean, - "\n" - " \n" - " \n" - " true\n" - " \n" - " \n" - " false\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Null value - common_chat_msg expected_null; - expected_null.role = "assistant"; - expected_null.tool_calls = { - { "null_function", "{\"optional_param\":null}", "" } - }; - - test_parser_with_streaming( - expected_null, - "\n" - " \n" - " \n" - " null\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Negative numbers and scientific notation - common_chat_msg expected_numbers; - expected_numbers.role = "assistant"; - expected_numbers.tool_calls = { - { "math_function", "{\"negative\":-42,\"decimal\":-3.14,\"scientific\":1.23e-4}", "" } - }; - - test_parser_with_streaming( - expected_numbers, - "\n" - " \n" - " \n" - " -42\n" - " \n" - " \n" - " -3.14\n" - " \n" - " \n" - " 1.23e-4\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // XML-like content in parameters (should be escaped) - common_chat_msg expected_xml_content; - expected_xml_content.role = "assistant"; - expected_xml_content.tool_calls = { - { "xml_function", "{\"xml_content\":\"value\"}", "" } - }; - - test_parser_with_streaming( - expected_xml_content, - "\n" - " \n" - " \n" - " value\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Quotes and escape characters - common_chat_msg expected_quotes; - expected_quotes.role = "assistant"; - expected_quotes.tool_calls = { - { "quote_function", "{\"message\":\"She said \\\"Hello!\\\" and left.\"}", "" } - }; - - test_parser_with_streaming( - expected_quotes, - "\n" - " \n" - " \n" - " She said \"Hello!\" and left.\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Long parameter value (simplified) - std::string long_text = "This is a long text parameter that should test the parser's ability to handle larger amounts of text data."; - - common_chat_msg expected_long_text; - expected_long_text.role = "assistant"; - expected_long_text.tool_calls = { - { "long_function", "{\"long_text\":\"" + long_text + "\"}", "" } - }; - - test_parser_with_streaming( - expected_long_text, - "\n" - " \n" - " \n" - " " + long_text + "\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Mixed content with text before and after tool call - common_chat_msg expected_mixed_content; - expected_mixed_content.role = "assistant"; - expected_mixed_content.content = "I'll help you search for products. "; - expected_mixed_content.tool_calls = { - { "search_function", "{\"query\":\"laptops\"}", "" } - }; - - test_parser_with_streaming( - expected_mixed_content, - "I'll help you search for products. \n" - " \n" - " \n" - " laptops\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Compact format (no extra whitespace) - common_chat_msg expected_compact; - expected_compact.role = "assistant"; - expected_compact.tool_calls = { - { "compact_function", "{\"param\":\"value\"}", "" } - }; - - test_parser_with_streaming( - expected_compact, - "value", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Function name with underscores and numbers - common_chat_msg expected_complex_name; - expected_complex_name.role = "assistant"; - expected_complex_name.tool_calls = { - { "get_user_data_v2", "{\"user_id\":12345}", "" } - }; - - test_parser_with_streaming( - expected_complex_name, - "\n" - " \n" - " \n" - " 12345\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Parameter names with underscores and numbers - common_chat_msg expected_complex_params; - expected_complex_params.role = "assistant"; - expected_complex_params.tool_calls = { - { "test_function", "{\"param_1\":\"value1\",\"param_2_name\":\"value2\",\"param3\":123}", "" } - }; - - test_parser_with_streaming( - expected_complex_params, - "\n" - " \n" - " \n" - " value1\n" - " \n" - " \n" - " value2\n" - " \n" - " \n" - " 123\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Very deeply nested XML content in parameter - common_chat_msg expected_deep_xml; - expected_deep_xml.role = "assistant"; - expected_deep_xml.tool_calls = { - { "xml_parser", "{\"xml\":\"deep content\"}", "" } - }; - - test_parser_with_streaming( - expected_deep_xml, - "\n" - " \n" - " \n" - " deep content\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Parameter with only whitespace - common_chat_msg expected_whitespace_param; - expected_whitespace_param.role = "assistant"; - expected_whitespace_param.tool_calls = { - { "whitespace_function", "{\"spaces\":\"\"}", "" } - }; - - test_parser_with_streaming( - expected_whitespace_param, - "\n" - " \n" - " \n" - " \n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Parameter with tabs and mixed whitespace - common_chat_msg expected_mixed_whitespace; - expected_mixed_whitespace.role = "assistant"; - expected_mixed_whitespace.tool_calls = { - { "tab_function", "{\"content\":\"line1\\n\\tindented line\\n spaces\"}", "" } - }; - - test_parser_with_streaming( - expected_mixed_whitespace, - "\n" - " \n" - " \n" - "line1\n" - "\tindented line\n" - " spaces\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Control characters and special Unicode - common_chat_msg expected_control_chars; - expected_control_chars.role = "assistant"; - expected_control_chars.tool_calls = { - { "control_function", "{\"text\":\"Line1\\nLine2\\tTabbed\\rCarriage return\"}", "" } - }; - - test_parser_with_streaming( - expected_control_chars, - "\n" - " \n" - " \n" - "Line1\nLine2\tTabbed\rCarriage return\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Emoji and extended Unicode characters - common_chat_msg expected_emoji; - expected_emoji.role = "assistant"; - expected_emoji.tool_calls = { - { "emoji_function", "{\"message\":\"Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\"}", "" } - }; - - test_parser_with_streaming( - expected_emoji, - "\n" - " \n" - " \n" - " Hello! 👋 🌟 🚀 Testing emojis: 😀😃😄😁 and symbols: ∑∏∆∇\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Mathematical expressions and formulas - common_chat_msg expected_math; - expected_math.role = "assistant"; - expected_math.tool_calls = { - { "math_function", "{\"formula\":\"E = mc² and ∫f(x)dx = F(x) + C\"}", "" } - }; - - test_parser_with_streaming( - expected_math, - "\n" - " \n" - " \n" - " E = mc² and ∫f(x)dx = F(x) + C\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // SQL injection-like content (should be safely escaped) - common_chat_msg expected_sql; - expected_sql.role = "assistant"; - expected_sql.tool_calls = { - { "sql_function", "{\"query\":\"SELECT * FROM users WHERE id = 1; DROP TABLE users; --\"}", "" } - }; - - test_parser_with_streaming( - expected_sql, - "\n" - " \n" - " \n" - " SELECT * FROM users WHERE id = 1; DROP TABLE users; --\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // HTML/XML injection content - common_chat_msg expected_html; - expected_html.role = "assistant"; - expected_html.tool_calls = { - { "html_function", "{\"content\":\"\"}", "" } - }; - - test_parser_with_streaming( - expected_html, - "\n" - " \n" - " \n" - " \n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Binary-like content (base64) - common_chat_msg expected_binary; - expected_binary.role = "assistant"; - expected_binary.tool_calls = { - { "binary_function", "{\"data\":\"SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\"}", "" } - }; - - test_parser_with_streaming( - expected_binary, - "\n" - " \n" - " \n" - " SGVsbG8gV29ybGQhIFRoaXMgaXMgYmFzZTY0IGVuY29kZWQgdGV4dC4=\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - - // Very large numbers (should be parsed as scientific notation) - common_chat_msg expected_large_numbers; - expected_large_numbers.role = "assistant"; - expected_large_numbers.tool_calls = { - { "number_function", "{\"big_int\":1e+60}", "" } // Large number becomes scientific notation - }; - - test_parser_with_streaming( - expected_large_numbers, - "\n" - " \n" - " \n" - " 999999999999999999999999999999999999999999999999999999999999\n" - " \n" - " \n" - "", - [&](const std::string &msg) { return test_chat_parse(msg, /* is_partial= */ true, {COMMON_CHAT_FORMAT_QWEN3_CODER_XML}); }); - } - - { - // Qwen3-Coder template - auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja"); - common_chat_templates_inputs inputs; - inputs.messages = { message_user }; - - common_chat_tool qwen_union_tool { - /* .name = */ "qwen_union", - /* .description = */ "Test tool for union/anyOf handling", - /* .parameters = */ R"({ - "type": "object", - "properties": { - "priority": { "type": ["number", "null"] }, - "maybe_text": { "anyOf": [ { "type": "string" } ] }, - "config": { "anyOf": [ { "type": "object" }, { "type": "null" } ] } - }, - "required": [] - })", - }; - inputs.tools = { qwen_union_tool }; - - auto params = common_chat_templates_apply(tmpls.get(), inputs); - assert_equals(COMMON_CHAT_FORMAT_QWEN3_CODER_XML, params.format); - assert_equals(false, params.grammar.empty()); - - // Grammar should compile successfully - auto grammar = build_grammar(params.grammar); - GGML_ASSERT(grammar && "Failed to build Qwen3-Coder grammar with union types"); - } - { // Step-3.5-Flash template: uses same XML output format as Qwen3-Coder and Nemotron v3, // but with support. Routes to the Nemotron v3 PEG parser for streaming and @@ -3665,6 +3143,135 @@ static void test_template_output_peg_parsers() { }); } + { + // Qwen3-Coder + auto tmpls = read_templates("models/templates/Qwen3-Coder.jinja"); + + // Test basic message + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = "Hello, world!\nWhat's up?"; + t.expect = message_assist; + }); + + // Test tool call + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + ""; + t.params.tools = {special_function_tool}; + t.expect = message_assist_call; + }); + + // Test parallel tool calls + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "\n" + "1\n" + "\n" + "\n" + "2\n" + "\n" + "\n" + ""; + t.params.parallel_tool_calls = true; + t.params.tools = {special_function_tool, special_function_tool_with_optional_param}; + + t.expect.tool_calls = {{ + /* .name = */ "special_function", + /* .arguments = */ R"({"arg1": 1})", + /* .id = */ {}, + }, { + /* .name = */ "special_function_with_opt", + /* .arguments = */ R"({"arg1": 1, "arg2": 2})", + /* .id = */ {}, + }}; + }); + + // Test tool call with string parameter + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + "\n" + ""; + t.params.tools = {python_tool}; + + t.expect.tool_calls = {{ + /* .name = */ "python", + /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", + /* .id = */ {}, + }}; + }); + + // Test tool call with JSON parameter + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "[{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]\n" + "\n" + "\n" + ""; + t.params.tools = {todo_list_tool}; + + t.expect.tool_calls = {{ + /* .name = */ "todo_list", + /* .arguments = */ "{\"todos\": [{\"item\": \"Check stuff\", \"selected\": false}, {\"item\": \"Prepare stuff\", \"selected\": true}]}", + /* .id = */ {}, + }}; + }); + + // Test tool call with string parameter and no closing tag + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = + "\n" + "\n" + "\n" + "def hello():\n" + " print(\"Hello, world!\")\n" + "\n" + "hello()\n" + "\n" + ""; + t.params.tools = {python_tool}; + + t.expect.tool_calls = {{ + /* .name = */ "python", + /* .arguments = */ "{\"code\": \"def hello():\\n print(\\\"Hello, world!\\\")\\n\\nhello()\"}", + /* .id = */ {}, + }}; + }); + + // Test response format + test_peg_parser(tmpls.get(), [&](auto & t) { + t.input = R"({"amount": 123.45, "date": "2025-12-03"})"; + t.params.json_schema = invoice_schema; + + t.expect.content = R"({"amount": 123.45, "date": "2025-12-03"})"; + }); + } + { // NVIDIA Nemotron-3 Nano auto tmpls = read_templates("models/templates/NVIDIA-Nemotron-3-Nano-30B-A3B-BF16.jinja"); diff --git a/tests/test-gguf-model-data.cpp b/tests/test-gguf-model-data.cpp new file mode 100644 index 0000000000..cc0174961d --- /dev/null +++ b/tests/test-gguf-model-data.cpp @@ -0,0 +1,121 @@ +#include "gguf-model-data.h" + +#include + +#define TEST_ASSERT(cond, msg) \ + do { \ + if (!(cond)) { \ + fprintf(stderr, "FAIL: %s (line %d): %s\n", #cond, __LINE__, msg); \ + return 1; \ + } \ + } while (0) + +int main() { + fprintf(stderr, "=== test-gguf-model-data ===\n"); + + // Fetch Qwen3-0.6B Q8_0 metadata + auto result = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0"); + + if (!result.has_value()) { + fprintf(stderr, "SKIP: could not fetch model metadata (no network or HTTP disabled)\n"); + return 0; + } + + const auto & model = result.value(); + + fprintf(stderr, "Architecture: %s\n", model.architecture.c_str()); + fprintf(stderr, "n_embd: %u\n", model.n_embd); + fprintf(stderr, "n_ff: %u\n", model.n_ff); + fprintf(stderr, "n_vocab: %u\n", model.n_vocab); + fprintf(stderr, "n_layer: %u\n", model.n_layer); + fprintf(stderr, "n_head: %u\n", model.n_head); + fprintf(stderr, "n_head_kv: %u\n", model.n_head_kv); + fprintf(stderr, "n_expert: %u\n", model.n_expert); + fprintf(stderr, "n_embd_head_k: %u\n", model.n_embd_head_k); + fprintf(stderr, "n_embd_head_v: %u\n", model.n_embd_head_v); + fprintf(stderr, "tensors: %zu\n", model.tensors.size()); + + // Verify architecture + TEST_ASSERT(model.architecture == "qwen3", "expected architecture 'qwen3'"); + + // Verify key dimensions (Qwen3-0.6B) + TEST_ASSERT(model.n_layer == 28, "expected n_layer == 28"); + TEST_ASSERT(model.n_embd == 1024, "expected n_embd == 1024"); + TEST_ASSERT(model.n_head == 16, "expected n_head == 16"); + TEST_ASSERT(model.n_head_kv == 8, "expected n_head_kv == 8"); + TEST_ASSERT(model.n_expert == 0, "expected n_expert == 0 (not MoE)"); + TEST_ASSERT(model.n_vocab == 151936, "expected n_vocab == 151936"); + + // Verify tensor count + TEST_ASSERT(model.tensors.size() == 311, "expected tensor count == 311"); + + // Verify known tensor names exist + bool found_attn_q = false; + bool found_token_embd = false; + bool found_output_norm = false; + for (const auto & t : model.tensors) { + if (t.name == "blk.0.attn_q.weight") { + found_attn_q = true; + } + if (t.name == "token_embd.weight") { + found_token_embd = true; + } + if (t.name == "output_norm.weight") { + found_output_norm = true; + } + } + TEST_ASSERT(found_attn_q, "expected tensor 'blk.0.attn_q.weight'"); + TEST_ASSERT(found_token_embd, "expected tensor 'token_embd.weight'"); + TEST_ASSERT(found_output_norm, "expected tensor 'output_norm.weight'"); + + // Verify token_embd.weight shape + for (const auto & t : model.tensors) { + if (t.name == "token_embd.weight") { + TEST_ASSERT(t.ne[0] == 1024, "expected token_embd.weight ne[0] == 1024"); + TEST_ASSERT(t.n_dims == 2, "expected token_embd.weight to be 2D"); + break; + } + } + + // Test that second call uses cache (just call again, it should work) + auto result2 = gguf_fetch_model_meta("ggml-org/Qwen3-0.6B-GGUF", "Q8_0"); + TEST_ASSERT(result2.has_value(), "cached fetch should succeed"); + TEST_ASSERT(result2->tensors.size() == model.tensors.size(), "cached result should match"); + + // Test a split MoE model without specifying quant (should default to Q8_0) + auto result3 = gguf_fetch_model_meta("ggml-org/GLM-4.6V-GGUF"); + if (!result3.has_value()) { + fprintf(stderr, "SKIP: could not fetch GLM-4.6V metadata (no network?)\n"); + return 0; + } + const auto & model3 = result3.value(); + + fprintf(stderr, "Architecture: %s\n", model3.architecture.c_str()); + fprintf(stderr, "n_embd: %u\n", model3.n_embd); + fprintf(stderr, "n_ff: %u\n", model3.n_ff); + fprintf(stderr, "n_vocab: %u\n", model3.n_vocab); + fprintf(stderr, "n_layer: %u\n", model3.n_layer); + fprintf(stderr, "n_head: %u\n", model3.n_head); + fprintf(stderr, "n_head_kv: %u\n", model3.n_head_kv); + fprintf(stderr, "n_expert: %u\n", model3.n_expert); + fprintf(stderr, "n_embd_head_k: %u\n", model3.n_embd_head_k); + fprintf(stderr, "n_embd_head_v: %u\n", model3.n_embd_head_v); + fprintf(stderr, "tensors: %zu\n", model3.tensors.size()); + + // Verify architecture + TEST_ASSERT(model3.architecture == "glm4moe", "expected architecture 'glm4moe'"); + + // Verify key dimensions (GLM-4.6V) + TEST_ASSERT(model3.n_layer == 46, "expected n_layer == 46"); + TEST_ASSERT(model3.n_embd == 4096, "expected n_embd == 4096"); + TEST_ASSERT(model3.n_head == 96, "expected n_head == 96"); + TEST_ASSERT(model3.n_head_kv == 8, "expected n_head_kv == 8"); + TEST_ASSERT(model3.n_expert == 128, "expected n_expert == 128 (MoE)"); + TEST_ASSERT(model3.n_vocab == 151552, "expected n_vocab == 151552"); + + // Verify tensor count + TEST_ASSERT(model3.tensors.size() == 780, "expected tensor count == 780"); + + fprintf(stderr, "=== ALL TESTS PASSED ===\n"); + return 0; +} diff --git a/tests/test-gguf.cpp b/tests/test-gguf.cpp index 84b7f3bc49..8ebd16ba82 100644 --- a/tests/test-gguf.cpp +++ b/tests/test-gguf.cpp @@ -48,6 +48,7 @@ enum handcrafted_file_type { HANDCRAFTED_DATA_NOT_ENOUGH_DATA = 10 + offset_has_data, HANDCRAFTED_DATA_BAD_ALIGN = 15 + offset_has_data, HANDCRAFTED_DATA_INCONSISTENT_ALIGN = 20 + offset_has_data, + HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW = 30 + offset_has_data, HANDCRAFTED_DATA_SUCCESS = 800 + offset_has_data, HANDCRAFTED_DATA_CUSTOM_ALIGN = 810 + offset_has_data, }; @@ -84,6 +85,7 @@ static std::string handcrafted_file_type_name(const enum handcrafted_file_type h case HANDCRAFTED_DATA_NOT_ENOUGH_DATA: return "DATA_NOT_ENOUGH_DATA"; case HANDCRAFTED_DATA_BAD_ALIGN: return "DATA_BAD_ALIGN"; case HANDCRAFTED_DATA_INCONSISTENT_ALIGN: return "DATA_INCONSISTENT_ALIGN"; + case HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW: return "DATA_MEM_SIZE_OVERFLOW"; case HANDCRAFTED_DATA_SUCCESS: return "DATA_SUCCESS"; case HANDCRAFTED_DATA_CUSTOM_ALIGN: return "DATA_CUSTOM_ALIGN"; } @@ -196,6 +198,13 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft tensor_configs = get_tensor_configs(rng); } + if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) { + tensor_configs.resize(2); + + tensor_configs[0] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } }; + tensor_configs[1] = { GGML_TYPE_I8, { 0x7FFFFFFFFFFFFFC0, 1, 1, 1 } }; + } + if (hft == HANDCRAFTED_HEADER_BAD_N_TENSORS) { const uint64_t n_tensors = -1; helper_write(file, n_tensors); @@ -397,7 +406,8 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft for (uint32_t i = 1; i < n_dims; ++i) { ne *= shape[i]; } - offset += GGML_PAD(ggml_row_size(type, ne), alignment); + + offset += GGML_PAD(ggml_row_size(type, ne), (uint64_t) alignment); } while (ftell(file) % alignment != 0) { @@ -411,6 +421,9 @@ static FILE * get_handcrafted_file(const unsigned int seed, const enum handcraft if (hft == HANDCRAFTED_DATA_NOT_ENOUGH_DATA) { nbytes -= 1; } + if (hft == HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW) { + nbytes = 32; + } for (uint64_t i = 0; i < nbytes; ++i) { const uint8_t random_byte = i % 256; helper_write(file, random_byte); @@ -704,6 +717,7 @@ static std::pair test_handcrafted_file(const unsigned int seed) { HANDCRAFTED_DATA_NOT_ENOUGH_DATA, HANDCRAFTED_DATA_BAD_ALIGN, HANDCRAFTED_DATA_INCONSISTENT_ALIGN, + HANDCRAFTED_DATA_MEM_SIZE_OVERFLOW, HANDCRAFTED_DATA_SUCCESS, HANDCRAFTED_DATA_CUSTOM_ALIGN, }; diff --git a/tests/test-jinja.cpp b/tests/test-jinja.cpp index f5197bd33f..05ea8ca9e9 100644 --- a/tests/test-jinja.cpp +++ b/tests/test-jinja.cpp @@ -32,6 +32,7 @@ static void test_string_methods(testing & t); static void test_array_methods(testing & t); static void test_object_methods(testing & t); static void test_hasher(testing & t); +static void test_stats(testing & t); static void test_fuzzing(testing & t); static bool g_python_mode = false; @@ -70,6 +71,7 @@ int main(int argc, char *argv[]) { t.test("object methods", test_object_methods); if (!g_python_mode) { t.test("hasher", test_hasher); + t.test("stats", test_stats); t.test("fuzzing", test_fuzzing); } @@ -1795,6 +1797,63 @@ static void test_hasher(testing & t) { }); } +static void test_stats(testing & t) { + static auto get_stats = [](const std::string & tmpl, const json & vars) -> jinja::value { + jinja::lexer lexer; + auto lexer_res = lexer.tokenize(tmpl); + + jinja::program prog = jinja::parse_from_tokens(lexer_res); + + jinja::context ctx(tmpl); + jinja::global_from_json(ctx, json{{ "val", vars }}, true); + ctx.is_get_stats = true; + + jinja::runtime runtime(ctx); + runtime.execute(prog); + + return ctx.get_val("val"); + }; + + t.test("stats", [](testing & t) { + jinja::value val = get_stats( + "{{val.num}} " + "{{val.str}} " + "{{val.arr[0]}} " + "{{val.obj.key1}} " + "{{val.nested | tojson}}", + // Note: the json below will be wrapped inside "val" in the context + json{ + {"num", 1}, + {"str", "abc"}, + {"arr", json::array({1, 2, 3})}, + {"obj", json::object({{"key1", 1}, {"key2", 2}, {"key3", 3}})}, + {"nested", json::object({ + {"inner_key1", json::array({1, 2})}, + {"inner_key2", json::object({{"a", "x"}, {"b", "y"}})} + })}, + {"mixed", json::object({ + {"used", 1}, + {"unused", 2}, + })}, + } + ); + + t.assert_true("num is used", val->at("num")->stats.used); + t.assert_true("str is used", val->at("str")->stats.used); + + t.assert_true("arr is used", val->at("arr")->stats.used); + t.assert_true("arr[0] is used", val->at("arr")->at(0)->stats.used); + t.assert_true("arr[1] is not used", !val->at("arr")->at(1)->stats.used); + + t.assert_true("obj is used", val->at("obj")->stats.used); + t.assert_true("obj.key1 is used", val->at("obj")->at("key1")->stats.used); + t.assert_true("obj.key2 is not used", !val->at("obj")->at("key2")->stats.used); + + t.assert_true("inner_key1[0] is used", val->at("nested")->at("inner_key1")->at(0)->stats.used); + t.assert_true("inner_key2.a is used", val->at("nested")->at("inner_key2")->at("a")->stats.used); + }); +} + static void test_template_cpp(testing & t, const std::string & name, const std::string & tmpl, const json & vars, const std::string & expect) { t.test(name, [&tmpl, &vars, &expect](testing & t) { jinja::lexer lexer; diff --git a/tests/test-tokenizer-0.sh b/tests/test-tokenizer-0.sh index 7ef009dc90..7024b00afe 100755 --- a/tests/test-tokenizer-0.sh +++ b/tests/test-tokenizer-0.sh @@ -13,7 +13,12 @@ fi name=$1 input=$2 -make -j tests/test-tokenizer-0 +# Build using CMake if binary doesn't exist +if [ ! -f ./build/bin/test-tokenizer-0 ]; then + printf "Building test-tokenizer-0 with CMake...\n" + cmake -B build -DLLAMA_BUILD_TESTS=ON + cmake --build build --target test-tokenizer-0 -j +fi printf "Testing %s on %s ...\n" $name $input @@ -23,7 +28,7 @@ printf "Tokenizing using (py) Python AutoTokenizer ...\n" python3 ./tests/test-tokenizer-0.py ./models/tokenizers/$name --fname-tok $input > /tmp/test-tokenizer-0-$name-py.log 2>&1 printf "Tokenizing using (cpp) llama.cpp ...\n" -./tests/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1 +./build/bin/test-tokenizer-0 ./models/ggml-vocab-$name.gguf $input > /tmp/test-tokenizer-0-$name-cpp.log 2>&1 cat /tmp/test-tokenizer-0-$name-py.log | grep "tokenized in" cat /tmp/test-tokenizer-0-$name-cpp.log | grep "tokenized in" diff --git a/tools/cli/README.md b/tools/cli/README.md index 4a15cbad9d..22d3fc87e9 100644 --- a/tools/cli/README.md +++ b/tools/cli/README.md @@ -57,8 +57,8 @@ | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | -| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | -| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)
(env: LLAMA_ARG_DIO) | +| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | +| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | | `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | | `-dev, --device ` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | | `--list-devices` | print list of available devices and exit | @@ -109,14 +109,14 @@ | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--temp N` | temperature (default: 0.80) | +| `--temp, --temperature N` | temperature (default: 0.80) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | | `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) | | `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) | -| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | +| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | +| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | | `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) | | `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) | diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index ad421e6326..e57bf52e36 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -380,6 +380,15 @@ int main(int argc, char ** argv) { console::error("file does not exist or cannot be opened: '%s'\n", fname.c_str()); continue; } + if (inf.fim_sep_token != LLAMA_TOKEN_NULL) { + cur_msg += common_token_to_piece(ctx_cli.ctx_server.get_llama_context(), inf.fim_sep_token, true); + cur_msg += fname; + cur_msg.push_back('\n'); + } else { + cur_msg += "--- File: "; + cur_msg += fname; + cur_msg += " ---\n"; + } cur_msg += marker; console::log("Loaded text from '%s'\n", fname.c_str()); continue; diff --git a/tools/completion/README.md b/tools/completion/README.md index 3ca3e68454..bcc0887659 100644 --- a/tools/completion/README.md +++ b/tools/completion/README.md @@ -140,8 +140,8 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `-np, --parallel N` | number of parallel sequences to decode (default: 1)
(env: LLAMA_ARG_N_PARALLEL) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | -| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | -| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)
(env: LLAMA_ARG_DIO) | +| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | +| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | | `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | | `-dev, --device ` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | | `--list-devices` | print list of available devices and exit | @@ -192,14 +192,14 @@ llama-completion.exe -m models\gemma-1.1-7b-it.Q4_K_M.gguf --ignore-eos -n -1 | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--temp N` | temperature (default: 0.80) | +| `--temp, --temperature N` | temperature (default: 0.80) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | | `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) | | `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) | -| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | +| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | +| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | | `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) | | `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) | diff --git a/tools/completion/completion.cpp b/tools/completion/completion.cpp index 977132756f..aed2c0e38f 100644 --- a/tools/completion/completion.cpp +++ b/tools/completion/completion.cpp @@ -387,6 +387,17 @@ int main(int argc, char ** argv) { } session_do_save = !path_session.empty() && n_match < embd_inp.size() && !params.prompt_cache_ro; + + // Logits are not stored as part of the session state so we need to + // "replay" the last token to get logits for sampling. + if (!session_tokens.empty() && n_match > 0 && n_match == session_tokens.size()) { + if (!common_replay_last_token(ctx, session_tokens.back(), n_match)) { + return 1; + } + + session_do_save = false; + LOG_INF("%s: replayed last token from session\n", __func__); + } } // number of tokens to keep when resetting context @@ -675,40 +686,27 @@ int main(int argc, char ** argv) { } if (!embd.empty()) { - int n_eval = (int) embd.size(); - LOG_DBG("eval: %s\n", string_from(ctx, embd).c_str()); - - GGML_ASSERT(n_eval <= params.n_batch); - if (llama_decode(ctx, llama_batch_get_one(embd.data(), n_eval))) { - LOG_ERR("%s : failed to eval\n", __func__); + const bool is_last_batch = (n_consumed >= (int) embd_inp.size()); + const bool save_now = session_do_save && is_last_batch; + if (!common_prompt_batch_decode(ctx, embd, n_past, params.n_batch, path_session, save_now)) { return 1; } - - n_past += n_eval; + session_tokens.insert(session_tokens.end(), embd.begin(), embd.begin()); + n_session_consumed = session_tokens.size(); + session_do_save = false; LOG_DBG("n_past = %d\n", n_past); + // Display total tokens alongside total time if (params.n_print > 0 && n_past % params.n_print == 0) { LOG_DBG("\n\033[31mTokens consumed so far = %d / %d \033[0m\n", n_past, n_ctx); } } - - if (!embd.empty() && !path_session.empty()) { - session_tokens.insert(session_tokens.end(), embd.begin(), embd.end()); - n_session_consumed = session_tokens.size(); - } } embd.clear(); if ((int) embd_inp.size() <= n_consumed && !is_interacting) { - // optionally save the session on first sample (for faster prompt loading next time) - if (session_do_save) { - session_do_save = false; - llama_state_save_file(ctx, path_session.c_str(), session_tokens.data(), session_tokens.size()); - - LOG_DBG("saved session to %s\n", path_session.c_str()); - } const llama_token id = common_sampler_sample(smpl, ctx, -1); diff --git a/tools/imatrix/imatrix.cpp b/tools/imatrix/imatrix.cpp index 669de55ddb..e025c114b4 100644 --- a/tools/imatrix/imatrix.cpp +++ b/tools/imatrix/imatrix.cpp @@ -912,7 +912,9 @@ static bool compute_imatrix(llama_context * ctx, const common_params & params, c const bool add_bos = llama_vocab_get_add_bos(vocab); - GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); + if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_LAST) { + GGML_ASSERT(!llama_vocab_get_add_eos(vocab)); + } auto tim1 = std::chrono::high_resolution_clock::now(); LOG_INF("%s: tokenizing the input ..\n", __func__); diff --git a/tools/mtmd/mtmd-helper.cpp b/tools/mtmd/mtmd-helper.cpp index 902a4b456d..c75f90730f 100644 --- a/tools/mtmd/mtmd-helper.cpp +++ b/tools/mtmd/mtmd-helper.cpp @@ -248,7 +248,7 @@ int32_t mtmd_helper_decode_image_chunk( int32_t n_tokens = mtmd_input_chunk_get_n_tokens(chunk); int32_t i_batch = 0; - int32_t n_img_batches = GGML_PAD(n_tokens, n_batch) / n_batch; + int32_t n_img_batches = (n_tokens + n_batch - 1) / n_batch; decode_embd_batch batch_embd(encoded_embd, n_tokens, n_pos_per_embd, n_mmproj_embd); if (mtmd_decode_use_mrope(ctx)) { diff --git a/tools/server/README.md b/tools/server/README.md index 0b56ca1e27..da16ddc756 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -74,8 +74,8 @@ For the full list of features, please refer to [server's changelog](https://gith | `-ctv, --cache-type-v TYPE` | KV cache data type for V
allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1
(default: f16)
(env: LLAMA_ARG_CACHE_TYPE_V) | | `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)
(env: LLAMA_ARG_DEFRAG_THOLD) | | `--mlock` | force system to keep model in RAM rather than swapping or compressing
(env: LLAMA_ARG_MLOCK) | -| `--mmap, --no-mmap` | whether to memory-map model. Explicitly enabling mmap disables direct-io. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | -| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. Takes precedence over --mmap (default: enabled)
(env: LLAMA_ARG_DIO) | +| `--mmap, --no-mmap` | whether to memory-map model. (if mmap disabled, slower load but may reduce pageouts if not using mlock) (default: enabled)
(env: LLAMA_ARG_MMAP) | +| `-dio, --direct-io, -ndio, --no-direct-io` | use DirectIO if available. (default: disabled)
(env: LLAMA_ARG_DIO) | | `--numa TYPE` | attempt optimizations that help on some NUMA systems
- distribute: spread execution evenly over all nodes
- isolate: only spawn threads on CPUs on the node that execution started on
- numactl: use the CPU map provided by numactl
if run without this previously, it is recommended to drop the system page cache before using this
see https://github.com/ggml-org/llama.cpp/issues/1437
(env: LLAMA_ARG_NUMA) | | `-dev, --device ` | comma-separated list of devices to use for offloading (none = don't offload)
use --list-devices to see a list of available devices
(env: LLAMA_ARG_DEVICE) | | `--list-devices` | print list of available devices and exit | @@ -126,14 +126,14 @@ For the full list of features, please refer to [server's changelog](https://gith | `-s, --seed SEED` | RNG seed (default: -1, use random seed for -1) | | `--sampler-seq, --sampling-seq SEQUENCE` | simplified sequence for samplers that will be used (default: edskypmxt) | | `--ignore-eos` | ignore end of stream token and continue generating (implies --logit-bias EOS-inf) | -| `--temp N` | temperature (default: 0.80) | +| `--temp, --temperature N` | temperature (default: 0.80) | | `--top-k N` | top-k sampling (default: 40, 0 = disabled)
(env: LLAMA_ARG_TOP_K) | | `--top-p N` | top-p sampling (default: 0.95, 1.0 = disabled) | | `--min-p N` | min-p sampling (default: 0.05, 0.0 = disabled) | -| `--top-nsigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | +| `--top-nsigma, --top-n-sigma N` | top-n-sigma sampling (default: -1.00, -1.0 = disabled) | | `--xtc-probability N` | xtc probability (default: 0.00, 0.0 = disabled) | | `--xtc-threshold N` | xtc threshold (default: 0.10, 1.0 = disabled) | -| `--typical N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | +| `--typical, --typical-p N` | locally typical sampling, parameter p (default: 1.00, 1.0 = disabled) | | `--repeat-last-n N` | last n tokens to consider for penalize (default: 64, 0 = disabled, -1 = ctx_size) | | `--repeat-penalty N` | penalize repeat sequence of tokens (default: 1.00, 1.0 = disabled) | | `--presence-penalty N` | repeat alpha presence penalty (default: 0.00, 0.0 = disabled) | @@ -162,9 +162,11 @@ For the full list of features, please refer to [server's changelog](https://gith | Argument | Explanation | | -------- | ----------- | +| `-lcs, --lookup-cache-static FNAME` | path to static lookup cache to use for lookup decoding (not updated by generation) | +| `-lcd, --lookup-cache-dynamic FNAME` | path to dynamic lookup cache to use for lookup decoding (updated by generation) | | `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)
(env: LLAMA_ARG_CTX_CHECKPOINTS) | | `-cram, --cache-ram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)
(env: LLAMA_ARG_CACHE_RAM) | -| `-kvu, --kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)
(env: LLAMA_ARG_KV_UNIFIED) | +| `-kvu, --kv-unified, -no-kvu, --no-kv-unified` | use single unified KV buffer shared across all sequences (default: enabled if number of slots is auto)
(env: LLAMA_ARG_KV_UNIFIED) | | `--context-shift, --no-context-shift` | whether to use context shift on infinite text generation (default: disabled)
(env: LLAMA_ARG_CONTEXT_SHIFT) | | `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode | | `-sp, --special` | special tokens output enabled (default: false) | @@ -182,7 +184,8 @@ For the full list of features, please refer to [server's changelog](https://gith | `-otd, --override-tensor-draft =,...` | override tensor buffer type for draft model | | `-cmoed, --cpu-moe-draft` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model
(env: LLAMA_ARG_CPU_MOE_DRAFT) | | `-ncmoed, --n-cpu-moe-draft N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model
(env: LLAMA_ARG_N_CPU_MOE_DRAFT) | -| `-a, --alias STRING` | set alias for model name (to be used by REST API)
(env: LLAMA_ARG_ALIAS) | +| `-a, --alias STRING` | set model name aliases, comma-separated (to be used by API)
(env: LLAMA_ARG_ALIAS) | +| `--tags STRING` | set model tags, comma-separated (informational, not used for routing)
(env: LLAMA_ARG_TAGS) | | `--host HOST` | ip address to listen, or bind to an UNIX socket if the address ends with .sock (default: 127.0.0.1)
(env: LLAMA_ARG_HOST) | | `--port PORT` | port to listen (default: 8080)
(env: LLAMA_ARG_PORT) | | `--path PATH` | path to serve static files from (default: )
(env: LLAMA_ARG_STATIC_PATH) | @@ -229,6 +232,10 @@ For the full list of features, please refer to [server's changelog](https://gith | `-ngld, --gpu-layers-draft, --n-gpu-layers-draft N` | max. number of draft model layers to store in VRAM, either an exact number, 'auto', or 'all' (default: auto)
(env: LLAMA_ARG_N_GPU_LAYERS_DRAFT) | | `-md, --model-draft FNAME` | draft model for speculative decoding (default: unused)
(env: LLAMA_ARG_MODEL_DRAFT) | | `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible | +| `--spec-type [none\|ngram-cache\|ngram-simple\|ngram-map-k\|ngram-map-k4v\|ngram-mod]` | type of speculative decoding to use when no draft model is provided (default: none) | +| `--spec-ngram-size-n N` | ngram size N for ngram-simple/ngram-map speculative decoding, length of lookup n-gram (default: 12) | +| `--spec-ngram-size-m N` | ngram size M for ngram-simple/ngram-map speculative decoding, length of draft m-gram (default: 48) | +| `--spec-ngram-min-hits N` | minimum hits for ngram-map speculative decoding (default: 1) | | `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) | | `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall | | `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) | @@ -1510,7 +1517,7 @@ version = 1 ; If the same key is defined in a specific preset, it will override the value in this global section. [*] c = 8192 -n-gpu-layer = 8 +n-gpu-layers = 8 ; If the key corresponds to an existing model on the server, ; this will be used as the default config for that model diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index c69481e798..a5465fcd13 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index d717fb6698..ff3c6d3c2b 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -231,19 +231,77 @@ server_tokens::server_tokens(mtmd::input_chunks & mtmd_chunks, bool has_mtmd) : server_tokens::server_tokens(const llama_tokens & tokens, bool has_mtmd) : has_mtmd(has_mtmd), tokens(tokens) { } -llama_pos server_tokens::pos_next() const { +llama_pos server_tokens::pos_next(int64_t n_tokens) const { if (!has_mtmd) { - return tokens.size(); + if (n_tokens < 0) { + return tokens.size(); + } + + return n_tokens; } - llama_pos res = tokens.size(); + if (n_tokens < 0) { + llama_pos res = tokens.size(); - for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { - const auto & chunk = it->second; - res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); + for (auto it = map_idx_to_media.begin(); it != map_idx_to_media.end(); ++it) { + const auto & chunk = it->second; + res += mtmd_input_chunk_get_n_pos(chunk.get()) - mtmd_input_chunk_get_n_tokens(chunk.get()); + } + + return res; } - return res; + int64_t idx = 0; + llama_pos pos = 0; + + GGML_ASSERT(n_tokens <= (int64_t)tokens.size()); + + while (idx < n_tokens) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; + } + } + + return pos; +} + +size_t server_tokens::size_up_to_pos(llama_pos max_pos) const { + if (!has_mtmd) { + return std::min((size_t)(max_pos + 1), tokens.size()); + } + + size_t idx = 0; + llama_pos pos = 0; + + while (idx < tokens.size()) { + const auto media_it = map_idx_to_media.find(idx); + if (media_it != map_idx_to_media.end()) { + const auto & chunk = media_it->second; + const llama_pos n_pos = mtmd_input_chunk_get_n_pos(chunk.get()); + const size_t n_tok = mtmd_input_chunk_get_n_tokens(chunk.get()); + + pos += n_pos; + idx += n_tok; + } else { + pos++; + idx++; + } + + if (pos > max_pos) { + break; + } + } + + return idx; } std::string server_tokens::str() const { @@ -1105,6 +1163,8 @@ json convert_responses_to_chatcmpl(const json & response_body) { }; for (json item : input_value) { + bool merge_prev = !chatcmpl_messages.empty() && chatcmpl_messages.back().value("role", "") == "assistant"; + if (exists_and_is_string(item, "content")) { // #responses_create-input-input_item_list-input_message-content-text_input // Only "Input message" contains item["content"]::string @@ -1193,7 +1253,7 @@ json convert_responses_to_chatcmpl(const json & response_body) { item.at("type") == "message" ) { // #responses_create-input-input_item_list-item-output_message - std::vector chatcmpl_content; + auto chatcmpl_content = json::array(); for (const auto & output_text : item.at("content")) { const std::string type = json_value(output_text, "type", std::string()); @@ -1210,10 +1270,19 @@ json convert_responses_to_chatcmpl(const json & response_body) { }); } - item.erase("status"); - item.erase("type"); - item["content"] = chatcmpl_content; - chatcmpl_messages.push_back(item); + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "content")) { + prev_msg["content"] = json::array(); + } + auto & prev_content = prev_msg["content"]; + prev_content.insert(prev_content.end(), chatcmpl_content.begin(), chatcmpl_content.end()); + } else { + item.erase("status"); + item.erase("type"); + item["content"] = chatcmpl_content; + chatcmpl_messages.push_back(item); + } } else if (exists_and_is_string(item, "arguments") && exists_and_is_string(item, "call_id") && exists_and_is_string(item, "name") && @@ -1221,24 +1290,27 @@ json convert_responses_to_chatcmpl(const json & response_body) { item.at("type") == "function_call" ) { // #responses_create-input-input_item_list-item-function_tool_call - json msg = json { - {"role", "assistant"}, - {"tool_calls", json::array({ json { - {"function", json { - {"arguments", item.at("arguments")}, - {"name", item.at("name")}, - }}, - {"id", item.at("call_id")}, - {"type", "function"}, - }})}, + json tool_call = { + {"function", json { + {"arguments", item.at("arguments")}, + {"name", item.at("name")}, + }}, + {"id", item.at("call_id")}, + {"type", "function"}, }; - if (!chatcmpl_messages.empty() && chatcmpl_messages.back().contains("reasoning_content")) { - // Move reasoning content from dummy message to tool call message - msg["reasoning_content"] = chatcmpl_messages.back().at("reasoning_content"); - chatcmpl_messages.pop_back(); + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + if (!exists_and_is_array(prev_msg, "tool_calls")) { + prev_msg["tool_calls"] = json::array(); + } + prev_msg["tool_calls"].push_back(tool_call); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"tool_calls", json::array({tool_call})} + }); } - chatcmpl_messages.push_back(msg); } else if (exists_and_is_string(item, "call_id") && (exists_and_is_string(item, "output") || exists_and_is_array(item, "output")) && exists_and_is_string(item, "type") && @@ -1282,12 +1354,16 @@ json convert_responses_to_chatcmpl(const json & response_body) { throw std::invalid_argument("item['content']['text'] is not a string"); } - // Pack reasoning content in dummy message - chatcmpl_messages.push_back(json { - {"role", "assistant"}, - {"content", json::array()}, - {"reasoning_content", item.at("content")[0].at("text")}, - }); + if (merge_prev) { + auto & prev_msg = chatcmpl_messages.back(); + prev_msg["reasoning_content"] = item.at("content")[0].at("text"); + } else { + chatcmpl_messages.push_back(json { + {"role", "assistant"}, + {"content", json::array()}, + {"reasoning_content", item.at("content")[0].at("text")}, + }); + } } else { throw std::invalid_argument("Cannot determine type of 'item'"); } @@ -1296,20 +1372,6 @@ json convert_responses_to_chatcmpl(const json & response_body) { throw std::invalid_argument("'input' must be a string or array of objects"); } - // Remove unused dummy message which contains - // reasoning content not followed by tool call - chatcmpl_messages.erase(std::remove_if( - chatcmpl_messages.begin(), - chatcmpl_messages.end(), - [](const json & x){ return x.contains("role") && - x.at("role") == "assistant" && - x.contains("content") && - x.at("content") == json::array() && - x.contains("reasoning_content"); - }), - chatcmpl_messages.end() - ); - chatcmpl_body["messages"] = chatcmpl_messages; if (response_body.contains("tools")) { diff --git a/tools/server/server-common.h b/tools/server/server-common.h index 2629a6bee9..4fb9e488df 100644 --- a/tools/server/server-common.h +++ b/tools/server/server-common.h @@ -167,7 +167,12 @@ public: // for debugging std::string str() const; - llama_pos pos_next() const; + // the next position after n_tokens. if n_tokens < 0, return the next position after all tokens. + llama_pos pos_next(int64_t n_tokens = -1) const; + + // number of tokens with position <= max_pos + size_t size_up_to_pos(llama_pos max_pos) const; + const mtmd::input_chunk_ptr & find_chunk(size_t idx) const; void push_back(llama_token tok); diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 8aab0d4c1b..aafed49502 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -580,6 +580,8 @@ private: float slot_prompt_similarity = 0.0f; std::string model_name; // name of the loaded model, to be used by API + std::set model_aliases; // additional names for the model + std::set model_tags; // informational tags bool sleeping = false; @@ -813,10 +815,9 @@ private: SRV_WRN("%s", "for more info see https://github.com/ggml-org/llama.cpp/pull/16391\n"); if (!params_base.model_alias.empty()) { - // user explicitly specified model name - model_name = params_base.model_alias; + // backward compat: use first alias as model name + model_name = *params_base.model_alias.begin(); } else if (!params_base.model.name.empty()) { - // use model name in registry format (for models in cache) model_name = params_base.model.name; } else { // fallback: derive model name from file name @@ -824,6 +825,9 @@ private: model_name = model_path.filename().string(); } + model_aliases = params_base.model_alias; + model_tags = params_base.model_tags; + if (!is_resume) { return init(); } @@ -995,9 +999,6 @@ private: // don't update the cache if the slot's context is empty update_cache = update_cache && tokens.size() > 0; - // TODO: mtmd does not support prompt cache - update_cache = update_cache && (ret->mctx == nullptr); - if (update_cache) { SRV_WRN("%s", "updating prompt cache\n"); @@ -1442,7 +1443,7 @@ private: res->id = slot.task->id; res->id_slot = slot.id; - res->index = slot.task->index; + res->index = slot.task->index; // keep copy of last generated text for debugging purposes if (slots_debug) { @@ -2282,15 +2283,15 @@ private: n_past = 0; } + llama_pos pos_next = slot.prompt.tokens.pos_next(n_past); + // note: when n_swa == 0, the model does not use SWA, which is equivalent to a window of 1 const auto n_swa = std::max(1, llama_model_n_swa(model)); // the largest pos_min required for a checkpoint to be useful - const auto pos_min_thold = std::max(0, n_past - n_swa); + const auto pos_min_thold = std::max(0, pos_next - n_swa); - // note: disallow with mtmd contexts for now - // https://github.com/ggml-org/llama.cpp/issues/17043 - if (!mctx && n_past > 0 && n_past < slot.prompt.n_tokens()) { + if (n_past > 0 && n_past < slot.prompt.n_tokens()) { const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); if (pos_min == -1) { SLT_ERR(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min); @@ -2341,9 +2342,6 @@ private: } if (pos_min > pos_min_thold) { - // TODO: support can be added in the future when corresponding vision models get released - GGML_ASSERT(!slot.prompt.tokens.has_mtmd); - SLT_WRN(slot, "n_past = %d, slot.prompt.tokens.size() = %d, seq_id = %d, pos_min = %d, n_swa = %d\n", n_past, (int) slot.prompt.tokens.size(), slot.id, pos_min, n_swa); // search for a context checkpoint @@ -2364,18 +2362,20 @@ private: const size_t n = llama_state_seq_set_data_ext(ctx, it->data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); if (n != checkpoint_size) { - SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + SLT_ERR(slot, "failed to restore context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); do_reset = true; //printf("[DEBUG] `do_reset` was set to `true` after failing to restore a checkpoint"); } else { - n_past = std::min(n_past, std::max(it->pos_min + 1, it->pos_max)); - SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", it->pos_min, it->pos_max, (float) checkpoint_size / 1024 / 1024); + pos_next = std::min(pos_next, std::max(it->pos_min + 1, it->pos_max)); + n_past = std::min(slot.prompt.tokens.size_up_to_pos(pos_next), (size_t) it->n_tokens); + SLT_WRN(slot, "restored context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", it->pos_min, it->pos_max, it->n_tokens, (float) checkpoint_size / 1024 / 1024); } } if (do_reset) { SLT_WRN(slot, "forcing full prompt re-processing due to lack of cache data (likely due to SWA or hybrid/recurrent memory, see %s)\n", "https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055"); + pos_next = 0; n_past = 0; } } @@ -2386,7 +2386,7 @@ private: for (auto it = slot.prompt.checkpoints.begin(); it != slot.prompt.checkpoints.end();) { const auto & cur = *it; if (cur.pos_min > pos_min_thold) { - SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, n_swa, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erased invalidated context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", n_swa = %d, size = %.3f MiB)\n", cur.pos_min, cur.pos_max, cur.n_tokens, n_swa, (float) cur.data.size() / 1024 / 1024); it = slot.prompt.checkpoints.erase(it); } else { ++it; @@ -2402,7 +2402,7 @@ private: SLT_WRN(slot, "n_past was set to %d\n", n_past); } - slot.n_prompt_tokens_cache = n_past; + slot.n_prompt_tokens_cache = n_past; slot.n_prompt_tokens_processed = 0; slot.prompt.tokens.keep_first(n_past); @@ -2520,10 +2520,6 @@ private: } } - // SLT_INF(slot, "new slot.prompt.tokens: %s\n", slot.slot.prompt.tokens.str().c_str()); - - SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); - // entire prompt has been processed if (slot.prompt.n_tokens() == slot.task->n_tokens()) { slot.state = SLOT_STATE_DONE_PROMPT; @@ -2536,8 +2532,6 @@ private: slot.n_decoded = 0; slot.i_batch = batch.n_tokens - 1; - SLT_INF(slot, "prompt done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); - slot.init_sampler(); const auto pos_min = llama_memory_seq_pos_min(llama_get_memory(ctx), slot.id); @@ -2549,13 +2543,15 @@ private: // no need to create checkpoints that are too close together do_checkpoint = do_checkpoint && (slot.prompt.checkpoints.empty() || pos_max > slot.prompt.checkpoints.back().pos_max + 64); + // note: we create the checkpoint before calling llama_decode(), so the current batch is not + // yet processed and therefore it is not part of the checkpoint. if (do_checkpoint) { while (slot.prompt.checkpoints.size() >= (size_t) params_base.n_ctx_checkpoints) { // make room for the new checkpoint, if needed const auto & cur = slot.prompt.checkpoints.front(); - SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "erasing old context checkpoint (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); slot.prompt.checkpoints.erase(slot.prompt.checkpoints.begin()); } @@ -2563,16 +2559,21 @@ private: const size_t checkpoint_size = llama_state_seq_get_size_ext(ctx, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); auto & cur = slot.prompt.checkpoints.emplace_back(server_prompt_checkpoint{ - /*.pos_min = */ pos_min, - /*.pos_max = */ pos_max, - /*.data = */ std::vector(checkpoint_size), + /*.pos_min = */ pos_min, + /*.pos_max = */ pos_max, + /*.n_tokens = */ slot.prompt.n_tokens() - batch.n_tokens, + /*.data = */ std::vector(checkpoint_size), }); llama_state_seq_get_data_ext(ctx, cur.data.data(), checkpoint_size, slot.id, LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY); - SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, size = %.3f MiB)\n", - (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, (float) cur.data.size() / 1024 / 1024); + SLT_WRN(slot, "created context checkpoint %d of %d (pos_min = %d, pos_max = %d, n_tokens = %" PRId64 ", size = %.3f MiB)\n", + (int) slot.prompt.checkpoints.size(), params_base.n_ctx_checkpoints, cur.pos_min, cur.pos_max, cur.n_tokens, (float) cur.data.size() / 1024 / 1024); } + + SLT_INF(slot, "prompt processing done, n_tokens = %d, batch.n_tokens = %d\n", slot.prompt.n_tokens(), batch.n_tokens); + } else { + SLT_INF(slot, "prompt processing progress, n_tokens = %d, batch.n_tokens = %d, progress = %f\n", slot.prompt.n_tokens(), batch.n_tokens, (float) slot.prompt.n_tokens() / slot.task->n_tokens()); } } @@ -2895,6 +2896,8 @@ server_context_meta server_context::get_meta() const { return server_context_meta { /* build_info */ build_info, /* model_name */ impl->model_name, + /* model_aliases */ impl->model_aliases, + /* model_tags */ impl->model_tags, /* model_path */ impl->params_base.model.path, /* has_mtmd */ impl->mctx != nullptr, /* has_inp_image */ impl->chat_params.allow_image, @@ -2911,6 +2914,9 @@ server_context_meta server_context::get_meta() const { /* fim_pre_token */ llama_vocab_fim_pre(impl->vocab), /* fim_sub_token */ llama_vocab_fim_suf(impl->vocab), /* fim_mid_token */ llama_vocab_fim_mid(impl->vocab), + /* fim_pad_token */ llama_vocab_fim_pad(impl->vocab), + /* fim_rep_token */ llama_vocab_fim_rep(impl->vocab), + /* fim_sep_token */ llama_vocab_fim_sep(impl->vocab), /* model_vocab_type */ llama_vocab_type(impl->vocab), /* model_vocab_n_tokens */ llama_vocab_n_tokens(impl->vocab), @@ -3688,6 +3694,8 @@ void server_routes::init_routes() { {"data", { { {"id", meta->model_name}, + {"aliases", meta->model_aliases}, + {"tags", meta->model_tags}, {"object", "model"}, {"created", std::time(0)}, {"owned_by", "llamacpp"}, diff --git a/tools/server/server-context.h b/tools/server/server-context.h index c0b5d373ff..75f3d2de56 100644 --- a/tools/server/server-context.h +++ b/tools/server/server-context.h @@ -1,3 +1,5 @@ +#pragma once + #include "server-http.h" #include "server-task.h" #include "server-queue.h" @@ -6,12 +8,15 @@ #include #include +#include struct server_context_impl; // private implementation struct server_context_meta { std::string build_info; std::string model_name; + std::set model_aliases; + std::set model_tags; std::string model_path; bool has_mtmd; bool has_inp_image; @@ -30,6 +35,9 @@ struct server_context_meta { llama_token fim_pre_token; llama_token fim_sub_token; llama_token fim_mid_token; + llama_token fim_pad_token; + llama_token fim_rep_token; + llama_token fim_sep_token; // model meta enum llama_vocab_type model_vocab_type; diff --git a/tools/server/server-http.cpp b/tools/server/server-http.cpp index 00897eeea5..129022a711 100644 --- a/tools/server/server-http.cpp +++ b/tools/server/server-http.cpp @@ -339,6 +339,17 @@ static std::map get_headers(const httplib::Request & r return headers; } +static std::string build_query_string(const httplib::Request & req) { + std::string qs; + for (const auto & [key, value] : req.params) { + if (!qs.empty()) { + qs += '&'; + } + qs += httplib::encode_query_component(key) + "=" + httplib::encode_query_component(value); + } + return qs; +} + // using unique_ptr for request to allow safe capturing in lambdas using server_http_req_ptr = std::unique_ptr; @@ -382,6 +393,7 @@ void server_http_context::get(const std::string & path, const server_http_contex get_params(req), get_headers(req), req.path, + build_query_string(req), req.body, req.is_connection_closed }); @@ -396,6 +408,7 @@ void server_http_context::post(const std::string & path, const server_http_conte get_params(req), get_headers(req), req.path, + build_query_string(req), req.body, req.is_connection_closed }); diff --git a/tools/server/server-http.h b/tools/server/server-http.h index 24c0b40117..3621064cdf 100644 --- a/tools/server/server-http.h +++ b/tools/server/server-http.h @@ -36,7 +36,8 @@ using server_http_res_ptr = std::unique_ptr; struct server_http_req { std::map params; // path_params + query_params std::map headers; // reserved for future use - std::string path; // reserved for future use + std::string path; + std::string query_string; // query parameters string (e.g. "action=save") std::string body; const std::function & should_stop; diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 57655476af..bc601237b7 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -184,6 +184,51 @@ void server_models::add_model(server_model_meta && meta) { if (mapping.find(meta.name) != mapping.end()) { throw std::runtime_error(string_format("model '%s' appears multiple times", meta.name.c_str())); } + + // check model name does not conflict with existing aliases + for (const auto & [key, inst] : mapping) { + if (inst.meta.aliases.count(meta.name)) { + throw std::runtime_error(string_format("model name '%s' conflicts with alias of model '%s'", + meta.name.c_str(), key.c_str())); + } + } + + // parse aliases from preset's --alias option (comma-separated) + std::string alias_str; + if (meta.preset.get_option("LLAMA_ARG_ALIAS", alias_str) && !alias_str.empty()) { + for (auto & alias : string_split(alias_str, ',')) { + alias = string_strip(alias); + if (!alias.empty()) { + meta.aliases.insert(alias); + } + } + } + + // parse tags from preset's --tags option (comma-separated) + std::string tags_str; + if (meta.preset.get_option("LLAMA_ARG_TAGS", tags_str) && !tags_str.empty()) { + for (auto & tag : string_split(tags_str, ',')) { + tag = string_strip(tag); + if (!tag.empty()) { + meta.tags.insert(tag); + } + } + } + + // validate aliases do not conflict with existing names or aliases + for (const auto & alias : meta.aliases) { + if (mapping.find(alias) != mapping.end()) { + throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with existing model name", + alias.c_str(), meta.name.c_str())); + } + for (const auto & [key, inst] : mapping) { + if (inst.meta.aliases.count(alias)) { + throw std::runtime_error(string_format("alias '%s' for model '%s' conflicts with alias of model '%s'", + alias.c_str(), meta.name.c_str(), key.c_str())); + } + } + } + meta.update_args(ctx_preset, bin_path); // render args std::string name = meta.name; mapping[name] = instance_t{ @@ -249,6 +294,8 @@ void server_models::load_models() { server_model_meta meta{ /* preset */ preset.second, /* name */ preset.first, + /* aliases */ {}, + /* tags */ {}, /* port */ 0, /* status */ SERVER_MODEL_STATUS_UNLOADED, /* last_used */ 0, @@ -265,10 +312,28 @@ void server_models::load_models() { for (const auto & [name, preset] : custom_presets) { custom_names.insert(name); } + auto join_set = [](const std::set & s) { + std::string result; + for (const auto & v : s) { + if (!result.empty()) { + result += ", "; + } + result += v; + } + return result; + }; + SRV_INF("Available models (%zu) (*: custom preset)\n", mapping.size()); for (const auto & [name, inst] : mapping) { bool has_custom = custom_names.find(name) != custom_names.end(); - SRV_INF(" %c %s\n", has_custom ? '*' : ' ', name.c_str()); + std::string info; + if (!inst.meta.aliases.empty()) { + info += " (aliases: " + join_set(inst.meta.aliases) + ")"; + } + if (!inst.meta.tags.empty()) { + info += " [tags: " + join_set(inst.meta.tags) + "]"; + } + SRV_INF(" %c %s%s\n", has_custom ? '*' : ' ', name.c_str(), info.c_str()); } } @@ -291,7 +356,9 @@ void server_models::load_models() { for (const auto & [name, inst] : mapping) { std::string val; if (inst.meta.preset.get_option(COMMON_ARG_PRESET_LOAD_ON_STARTUP, val)) { - models_to_load.push_back(name); + if (common_arg_utils::is_truthy(val)) { + models_to_load.push_back(name); + } } } if ((int)models_to_load.size() > base_params.models_max) { @@ -318,7 +385,15 @@ void server_models::update_meta(const std::string & name, const server_model_met bool server_models::has_model(const std::string & name) { std::lock_guard lk(mutex); - return mapping.find(name) != mapping.end(); + if (mapping.find(name) != mapping.end()) { + return true; + } + for (const auto & [key, inst] : mapping) { + if (inst.meta.aliases.count(name)) { + return true; + } + } + return false; } std::optional server_models::get_meta(const std::string & name) { @@ -327,6 +402,11 @@ std::optional server_models::get_meta(const std::string & nam if (it != mapping.end()) { return it->second.meta; } + for (const auto & [key, inst] : mapping) { + if (inst.meta.aliases.count(name)) { + return inst.meta; + } + } return std::nullopt; } @@ -697,11 +777,15 @@ server_http_res_ptr server_models::proxy_request(const server_http_req & req, co mapping[name].meta.last_used = ggml_time_ms(); } SRV_INF("proxying request to model %s on port %d\n", name.c_str(), meta->port); + std::string proxy_path = req.path; + if (!req.query_string.empty()) { + proxy_path += '?' + req.query_string; + } auto proxy = std::make_unique( method, CHILD_ADDR, meta->port, - req.path, + proxy_path, req.headers, req.body, req.should_stop, @@ -760,7 +844,7 @@ static void res_err(std::unique_ptr & res, const json & error_d res->data = safe_json_to_str({{ "error", error_data }}); } -static bool router_validate_model(const std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { +static bool router_validate_model(std::string & name, server_models & models, bool models_autoload, std::unique_ptr & res) { if (name.empty()) { res_err(res, format_error_response("model name is missing from the request", ERROR_TYPE_INVALID_REQUEST)); return false; @@ -770,6 +854,8 @@ static bool router_validate_model(const std::string & name, server_models & mode res_err(res, format_error_response(string_format("model '%s' not found", name.c_str()), ERROR_TYPE_INVALID_REQUEST)); return false; } + // resolve alias to canonical model name + name = meta->name; if (models_autoload) { models.ensure_model_loaded(name); } else { @@ -841,16 +927,16 @@ void server_models_routes::init_routes() { auto res = std::make_unique(); json body = json::parse(req.body); std::string name = json_value(body, "model", std::string()); - auto model = models.get_meta(name); - if (!model.has_value()) { + auto meta = models.get_meta(name); + if (!meta.has_value()) { res_err(res, format_error_response("model is not found", ERROR_TYPE_NOT_FOUND)); return res; } - if (model->status == SERVER_MODEL_STATUS_LOADED) { + if (meta->status == SERVER_MODEL_STATUS_LOADED) { res_err(res, format_error_response("model is already loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } - models.load(name); + models.load(meta->name); res_ok(res, {{"success", true}}); return res; }; @@ -871,6 +957,7 @@ void server_models_routes::init_routes() { preset_copy.unset_option("LLAMA_ARG_HOST"); preset_copy.unset_option("LLAMA_ARG_PORT"); preset_copy.unset_option("LLAMA_ARG_ALIAS"); + preset_copy.unset_option("LLAMA_ARG_TAGS"); status["preset"] = preset_copy.to_ini(); } if (meta.is_failed()) { @@ -879,6 +966,8 @@ void server_models_routes::init_routes() { } models_json.push_back(json { {"id", meta.name}, + {"aliases", meta.aliases}, + {"tags", meta.tags}, {"object", "model"}, // for OAI-compat {"owned_by", "llamacpp"}, // for OAI-compat {"created", t}, // for OAI-compat @@ -906,7 +995,7 @@ void server_models_routes::init_routes() { res_err(res, format_error_response("model is not loaded", ERROR_TYPE_INVALID_REQUEST)); return res; } - models.unload(name); + models.unload(model->name); res_ok(res, {{"success", true}}); return res; }; diff --git a/tools/server/server-models.h b/tools/server/server-models.h index a397abda4a..78abc8d72a 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -52,6 +52,8 @@ static std::string server_model_status_to_string(server_model_status status) { struct server_model_meta { common_preset preset; std::string name; + std::set aliases; // additional names that resolve to this model + std::set tags; // informational tags, not used for routing int port = 0; server_model_status status = SERVER_MODEL_STATUS_UNLOADED; int64_t last_used = 0; // for LRU unloading diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index a137427c69..d3aba18489 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -204,7 +204,8 @@ task_params server_task::params_from_json_cmpl( params.cache_prompt = json_value(data, "cache_prompt", defaults.cache_prompt); params.return_tokens = json_value(data, "return_tokens", false); params.return_progress = json_value(data, "return_progress", false); - params.n_predict = json_value(data, "n_predict", json_value(data, "max_tokens", defaults.n_predict)); + auto max_tokens = json_value(data, "max_tokens", defaults.n_predict); + params.n_predict = json_value(data, "n_predict", json_value(data, "max_completion_tokens", max_tokens)); params.n_indent = json_value(data, "n_indent", defaults.n_indent); params.n_keep = json_value(data, "n_keep", defaults.n_keep); params.n_discard = json_value(data, "n_discard", defaults.n_discard); @@ -1899,10 +1900,9 @@ server_prompt * server_prompt_cache::alloc(const server_prompt & prompt, size_t return nullptr; } - // TODO: for some reason we can't copy server_tokens, so we have to do this workaround auto & cur = states.emplace_back(); cur = { - /*.tokens =*/ server_tokens(prompt.tokens.get_text_tokens(), false), + /*.tokens =*/ prompt.tokens.clone(), /*.data =*/ std::move(state_data), /*.checkpoints =*/ prompt.checkpoints, }; diff --git a/tools/server/server-task.h b/tools/server/server-task.h index a69e8f1a3d..e2e3e5a582 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -557,6 +557,8 @@ struct server_prompt_checkpoint { llama_pos pos_min; llama_pos pos_max; + int64_t n_tokens; + std::vector data; size_t size() const { diff --git a/tools/server/server.cpp b/tools/server/server.cpp index d3d4316026..f353dcdde7 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -92,7 +92,7 @@ int main(int argc, char ** argv) { // for consistency between server router mode and single-model mode, we set the same model name as alias if (params.model_alias.empty() && !params.model.name.empty()) { - params.model_alias = params.model.name; + params.model_alias.insert(params.model.name); } common_init(); @@ -178,6 +178,7 @@ int main(int argc, char ** argv) { ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions)); ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint ctx_http.post("/v1/responses", ex_wrapper(routes.post_responses_oai)); + ctx_http.post("/responses", ex_wrapper(routes.post_responses_oai)); ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting ctx_http.post("/infill", ex_wrapper(routes.post_infill)); diff --git a/tools/server/tests/unit/test_basic.py b/tools/server/tests/unit/test_basic.py index 3405be3e25..d1b89cf1a9 100644 --- a/tools/server/tests/unit/test_basic.py +++ b/tools/server/tests/unit/test_basic.py @@ -94,3 +94,20 @@ def test_no_webui(): server.start() res = requests.get(url) assert res.status_code == 404 + + +def test_server_model_aliases_and_tags(): + global server + server.model_alias = "tinyllama-2,fim,code" + server.model_tags = "chat,fim,small" + server.start() + res = server.make_request("GET", "/models") + assert res.status_code == 200 + assert len(res.body["data"]) == 1 + model = res.body["data"][0] + # aliases field must contain all aliases + assert set(model["aliases"]) == {"tinyllama-2", "fim", "code"} + # tags field must contain all tags + assert set(model["tags"]) == {"chat", "fim", "small"} + # id is derived from first alias (alphabetical order from std::set) + assert model["id"] == "code" diff --git a/tools/server/tests/utils.py b/tools/server/tests/utils.py index f76bb1a911..5002999d9b 100644 --- a/tools/server/tests/utils.py +++ b/tools/server/tests/utils.py @@ -56,6 +56,7 @@ class ServerProcess: # custom options model_alias: str | None = None + model_tags: str | None = None model_url: str | None = None model_file: str | None = None model_draft: str | None = None @@ -180,6 +181,8 @@ class ServerProcess: server_args.extend(["--pooling", self.pooling]) if self.model_alias: server_args.extend(["--alias", self.model_alias]) + if self.model_tags: + server_args.extend(["--tags", self.model_tags]) if self.n_ctx: server_args.extend(["--ctx-size", self.n_ctx]) if self.n_slots: diff --git a/tools/server/webui/README.md b/tools/server/webui/README.md index 98b01fdcd7..6fc908e274 100644 --- a/tools/server/webui/README.md +++ b/tools/server/webui/README.md @@ -101,7 +101,7 @@ In a separate terminal, start the backend server: ./llama-server -m model.gguf # Multi-model (ROUTER mode) -./llama-server --model-store /path/to/models +./llama-server --models-dir /path/to/models ``` ### 3. Start Development Servers diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index c3cb8343fc..2130658dda 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -114,6 +114,11 @@ label: 'Render user content as Markdown', type: SettingsFieldType.CHECKBOX }, + { + key: SETTINGS_KEYS.FULL_HEIGHT_CODE_BLOCKS, + label: 'Use full height code blocks', + type: SettingsFieldType.CHECKBOX + }, { key: SETTINGS_KEYS.DISABLE_AUTO_SCROLL, label: 'Disable automatic scroll', diff --git a/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte index 0bc69a739f..a0944e18a0 100644 --- a/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte +++ b/tools/server/webui/src/lib/components/app/content/MarkdownContent.svelte @@ -38,6 +38,8 @@ import { ActionIconsCodeBlock, DialogCodePreview } from '$lib/components/app'; import { createAutoScrollController } from '$lib/hooks/use-auto-scroll.svelte'; import type { DatabaseMessageExtra } from '$lib/types/database'; + import { config } from '$lib/stores/settings.svelte'; + import { SETTINGS_KEYS } from '$lib/constants/settings-keys'; interface Props { attachments?: DatabaseMessageExtra[]; @@ -593,7 +595,12 @@ }); -
+
{#each renderedBlocks as block (block.id)}
@@ -914,6 +921,16 @@ line-height: 1.3; } + .full-height-code-blocks :global(.code-block-wrapper) { + max-height: none; + } + + .full-height-code-blocks :global(.code-block-scroll-container), + .full-height-code-blocks .streaming-code-scroll-container { + max-height: none; + overflow-y: visible; + } + div :global(.code-block-header) { display: flex; justify-content: space-between; diff --git a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte index e011fa6ec1..ebffae1212 100644 --- a/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte +++ b/tools/server/webui/src/lib/components/app/models/ModelsSelector.svelte @@ -251,9 +251,6 @@ return options.find((option) => option.id === activeId); } - if (options.length === 1) { - return options[0]; - } // No selection - return undefined to show "Select model" return undefined; } diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index 6f6dbea2ec..00dac3d6e9 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -22,6 +22,7 @@ export const SETTING_CONFIG_DEFAULT: Record = alwaysShowSidebarOnDesktop: false, autoShowSidebarOnNewChat: true, autoMicOnEmpty: false, + fullHeightCodeBlocks: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', backend_sampling: false, @@ -113,6 +114,8 @@ export const SETTING_CONFIG_INFO: Record = { 'Automatically show sidebar when starting a new chat. Disable to keep the sidebar hidden until you click on it.', autoMicOnEmpty: 'Automatically show microphone button instead of send button when textarea is empty for models with audio modality support.', + fullHeightCodeBlocks: + 'Always display code blocks at their full natural height, overriding any height limits.', pyInterpreterEnabled: 'Enable Python interpreter using Pyodide. Allows running Python code in markdown code blocks.', enableContinueGeneration: diff --git a/tools/server/webui/src/lib/constants/settings-keys.ts b/tools/server/webui/src/lib/constants/settings-keys.ts index 63960d4d56..38de41ffee 100644 --- a/tools/server/webui/src/lib/constants/settings-keys.ts +++ b/tools/server/webui/src/lib/constants/settings-keys.ts @@ -23,6 +23,7 @@ export const SETTINGS_KEYS = { DISABLE_AUTO_SCROLL: 'disableAutoScroll', ALWAYS_SHOW_SIDEBAR_ON_DESKTOP: 'alwaysShowSidebarOnDesktop', AUTO_SHOW_SIDEBAR_ON_NEW_CHAT: 'autoShowSidebarOnNewChat', + FULL_HEIGHT_CODE_BLOCKS: 'fullHeightCodeBlocks', // Sampling TEMPERATURE: 'temperature', DYNATEMP_RANGE: 'dynatemp_range', diff --git a/tools/server/webui/src/lib/services/parameter-sync.service.ts b/tools/server/webui/src/lib/services/parameter-sync.service.ts index 1d7666e955..1acb5ce453 100644 --- a/tools/server/webui/src/lib/services/parameter-sync.service.ts +++ b/tools/server/webui/src/lib/services/parameter-sync.service.ts @@ -153,6 +153,12 @@ export const SYNCABLE_PARAMETERS: SyncableParameter[] = [ serverKey: 'enableContinueGeneration', type: SyncableParameterType.BOOLEAN, canSync: true + }, + { + key: 'fullHeightCodeBlocks', + serverKey: 'fullHeightCodeBlocks', + type: SyncableParameterType.BOOLEAN, + canSync: true } ]; diff --git a/tools/server/webui/src/lib/stores/models.svelte.ts b/tools/server/webui/src/lib/stores/models.svelte.ts index 4cb6167220..c4cc3d3860 100644 --- a/tools/server/webui/src/lib/stores/models.svelte.ts +++ b/tools/server/webui/src/lib/stores/models.svelte.ts @@ -306,6 +306,16 @@ class ModelsStore { const response = await ModelsService.listRouter(); this.routerModels = response.data; await this.fetchModalitiesForLoadedModels(); + + const o = this.models.filter((option) => { + const modelProps = this.getModelProps(option.model); + + return modelProps?.webui !== false; + }); + + if (o.length === 1 && this.isModelLoaded(o[0].model)) { + this.selectModelById(o[0].id); + } } catch (error) { console.warn('Failed to fetch router models:', error); this.routerModels = []; diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index f2d3f98005..4960f9c861 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -171,7 +171,6 @@ endif() if (CPPHTTPLIB_OPENSSL_SUPPORT) target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) # used in server.cpp if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") - target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) find_library(SECURITY_FRAMEWORK Security REQUIRED) target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 9d24594f98..7f76978fd8 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1,5 +1,7 @@ #include "httplib.h" namespace httplib { +// httplib::any — type-erased value container (C++11 compatible) +// On C++17+ builds, thin wrappers around std::any are provided. /* * Implementation that will be part of the .cc file if split into .h + .cc. @@ -630,6 +632,56 @@ size_t to_utf8(int code, char *buff) { return 0; } +} // namespace detail + +namespace ws { +namespace impl { + +bool is_valid_utf8(const std::string &s) { + size_t i = 0; + auto n = s.size(); + while (i < n) { + auto c = static_cast(s[i]); + size_t len; + uint32_t cp; + if (c < 0x80) { + i++; + continue; + } else if ((c & 0xE0) == 0xC0) { + len = 2; + cp = c & 0x1F; + } else if ((c & 0xF0) == 0xE0) { + len = 3; + cp = c & 0x0F; + } else if ((c & 0xF8) == 0xF0) { + len = 4; + cp = c & 0x07; + } else { + return false; + } + if (i + len > n) { return false; } + for (size_t j = 1; j < len; j++) { + auto b = static_cast(s[i + j]); + if ((b & 0xC0) != 0x80) { return false; } + cp = (cp << 6) | (b & 0x3F); + } + // Overlong encoding check + if (len == 2 && cp < 0x80) { return false; } + if (len == 3 && cp < 0x800) { return false; } + if (len == 4 && cp < 0x10000) { return false; } + // Surrogate halves (U+D800..U+DFFF) and beyond U+10FFFF are invalid + if (cp >= 0xD800 && cp <= 0xDFFF) { return false; } + if (cp > 0x10FFFF) { return false; } + i += len; + } + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c std::string base64_encode(const std::string &in) { @@ -660,6 +712,281 @@ std::string base64_encode(const std::string &in) { return out; } +std::string sha1(const std::string &input) { + // RFC 3174 SHA-1 implementation + auto left_rotate = [](uint32_t x, uint32_t n) -> uint32_t { + return (x << n) | (x >> (32 - n)); + }; + + uint32_t h0 = 0x67452301; + uint32_t h1 = 0xEFCDAB89; + uint32_t h2 = 0x98BADCFE; + uint32_t h3 = 0x10325476; + uint32_t h4 = 0xC3D2E1F0; + + // Pre-processing: adding padding bits + std::string msg = input; + uint64_t original_bit_len = static_cast(msg.size()) * 8; + msg.push_back(static_cast(0x80)); + while (msg.size() % 64 != 56) { + msg.push_back(0); + } + + // Append original length in bits as 64-bit big-endian + for (int i = 56; i >= 0; i -= 8) { + msg.push_back(static_cast((original_bit_len >> i) & 0xFF)); + } + + // Process each 512-bit chunk + for (size_t offset = 0; offset < msg.size(); offset += 64) { + uint32_t w[80]; + + for (size_t i = 0; i < 16; i++) { + w[i] = + (static_cast(static_cast(msg[offset + i * 4])) + << 24) | + (static_cast(static_cast(msg[offset + i * 4 + 1])) + << 16) | + (static_cast(static_cast(msg[offset + i * 4 + 2])) + << 8) | + (static_cast( + static_cast(msg[offset + i * 4 + 3]))); + } + + for (int i = 16; i < 80; i++) { + w[i] = left_rotate(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1); + } + + uint32_t a = h0, b = h1, c = h2, d = h3, e = h4; + + for (int i = 0; i < 80; i++) { + uint32_t f, k; + if (i < 20) { + f = (b & c) | ((~b) & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + + uint32_t temp = left_rotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = left_rotate(b, 30); + b = a; + a = temp; + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + // Produce the final hash as a 20-byte binary string + std::string hash(20, '\0'); + for (size_t i = 0; i < 4; i++) { + hash[i] = static_cast((h0 >> (24 - i * 8)) & 0xFF); + hash[4 + i] = static_cast((h1 >> (24 - i * 8)) & 0xFF); + hash[8 + i] = static_cast((h2 >> (24 - i * 8)) & 0xFF); + hash[12 + i] = static_cast((h3 >> (24 - i * 8)) & 0xFF); + hash[16 + i] = static_cast((h4 >> (24 - i * 8)) & 0xFF); + } + return hash; +} + +std::string websocket_accept_key(const std::string &client_key) { + const std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + return base64_encode(sha1(client_key + magic)); +} + +bool is_websocket_upgrade(const Request &req) { + if (req.method != "GET") { return false; } + + // Check Upgrade: websocket (case-insensitive) + auto upgrade_it = req.headers.find("Upgrade"); + if (upgrade_it == req.headers.end()) { return false; } + auto upgrade_val = upgrade_it->second; + std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), + ::tolower); + if (upgrade_val != "websocket") { return false; } + + // Check Connection header contains "Upgrade" + auto connection_it = req.headers.find("Connection"); + if (connection_it == req.headers.end()) { return false; } + auto connection_val = connection_it->second; + std::transform(connection_val.begin(), connection_val.end(), + connection_val.begin(), ::tolower); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars) + // RFC 6455 Section 4.2.1 + auto ws_key = req.get_header_value("Sec-WebSocket-Key"); + if (ws_key.size() != 24 || ws_key[22] != '=' || ws_key[23] != '=') { + return false; + } + static const std::string b64chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + for (size_t i = 0; i < 22; i++) { + if (b64chars.find(ws_key[i]) == std::string::npos) { return false; } + } + + // Check Sec-WebSocket-Version: 13 + auto version = req.get_header_value("Sec-WebSocket-Version"); + if (version != "13") { return false; } + + return true; +} + +bool write_websocket_frame(Stream &strm, ws::Opcode opcode, + const char *data, size_t len, bool fin, + bool mask) { + // First byte: FIN + opcode + uint8_t header[2]; + header[0] = static_cast((fin ? 0x80 : 0x00) | + (static_cast(opcode) & 0x0F)); + + // Second byte: MASK + payload length + if (len < 126) { + header[1] = static_cast(len); + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + } else if (len <= 0xFFFF) { + header[1] = 126; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[2]; + ext[0] = static_cast((len >> 8) & 0xFF); + ext[1] = static_cast(len & 0xFF); + if (strm.write(reinterpret_cast(ext), 2) < 0) { return false; } + } else { + header[1] = 127; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[8]; + for (int i = 7; i >= 0; i--) { + ext[7 - i] = static_cast((len >> (i * 8)) & 0xFF); + } + if (strm.write(reinterpret_cast(ext), 8) < 0) { return false; } + } + + if (mask) { + // Generate random mask key + thread_local std::mt19937 rng(std::random_device{}()); + uint8_t mask_key[4]; + auto r = rng(); + std::memcpy(mask_key, &r, 4); + if (strm.write(reinterpret_cast(mask_key), 4) < 0) { return false; } + + // Write masked payload in chunks + const size_t chunk_size = 4096; + std::vector buf((std::min)(len, chunk_size)); + for (size_t offset = 0; offset < len; offset += chunk_size) { + size_t n = (std::min)(chunk_size, len - offset); + for (size_t i = 0; i < n; i++) { + buf[i] = + data[offset + i] ^ static_cast(mask_key[(offset + i) % 4]); + } + if (strm.write(buf.data(), n) < 0) { return false; } + } + } else { + if (len > 0) { + if (strm.write(data, len) < 0) { return false; } + } + } + + return true; +} + +} // namespace detail + +namespace ws { +namespace impl { + +bool read_websocket_frame(Stream &strm, Opcode &opcode, + std::string &payload, bool &fin, + bool expect_masked, size_t max_len) { + // Read first 2 bytes + uint8_t header[2]; + if (strm.read(reinterpret_cast(header), 2) != 2) { return false; } + + fin = (header[0] & 0x80) != 0; + + // RSV1, RSV2, RSV3 must be 0 when no extension is negotiated + if (header[0] & 0x70) { return false; } + + opcode = static_cast(header[0] & 0x0F); + bool masked = (header[1] & 0x80) != 0; + uint64_t payload_len = header[1] & 0x7F; + + // RFC 6455 Section 5.5: control frames MUST NOT be fragmented and + // MUST have a payload length of 125 bytes or less + bool is_control = (static_cast(opcode) & 0x08) != 0; + if (is_control) { + if (!fin) { return false; } + if (payload_len > 125) { return false; } + } + + if (masked != expect_masked) { return false; } + + // Extended payload length + if (payload_len == 126) { + uint8_t ext[2]; + if (strm.read(reinterpret_cast(ext), 2) != 2) { return false; } + payload_len = (static_cast(ext[0]) << 8) | ext[1]; + } else if (payload_len == 127) { + uint8_t ext[8]; + if (strm.read(reinterpret_cast(ext), 8) != 8) { return false; } + // RFC 6455 Section 5.2: the most significant bit MUST be 0 + if (ext[0] & 0x80) { return false; } + payload_len = 0; + for (int i = 0; i < 8; i++) { + payload_len = (payload_len << 8) | ext[i]; + } + } + + if (payload_len > max_len) { return false; } + + // Read mask key if present + uint8_t mask_key[4] = {0}; + if (masked) { + if (strm.read(reinterpret_cast(mask_key), 4) != 4) { return false; } + } + + // Read payload + payload.resize(static_cast(payload_len)); + if (payload_len > 0) { + size_t total_read = 0; + while (total_read < payload_len) { + auto n = strm.read(&payload[total_read], + static_cast(payload_len - total_read)); + if (n <= 0) { return false; } + total_read += static_cast(n); + } + } + + // Unmask if needed + if (masked) { + for (size_t i = 0; i < payload.size(); i++) { + payload[i] ^= static_cast(mask_key[i % 4]); + } + } + + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + bool is_valid_path(const std::string &path) { size_t level = 0; size_t i = 0; @@ -1333,12 +1660,14 @@ public: bool is_readable() const override; bool wait_readable() const override; bool wait_writable() const override; + bool is_peer_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; private: socket_t sock_; @@ -2242,10 +2571,46 @@ find_content_type(const std::string &path, } } +std::string +extract_media_type(const std::string &content_type, + std::map *params = nullptr) { + // Extract type/subtype from Content-Type value (RFC 2045) + // e.g. "application/json; charset=utf-8" -> "application/json" + auto media_type = content_type; + auto semicolon_pos = media_type.find(';'); + if (semicolon_pos != std::string::npos) { + auto param_str = media_type.substr(semicolon_pos + 1); + media_type = media_type.substr(0, semicolon_pos); + + if (params) { + // Parse parameters: key=value pairs separated by ';' + split(param_str.data(), param_str.data() + param_str.size(), ';', + [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + if (!key.empty()) { + params->emplace(trim_copy(key), trim_double_quotes_copy(val)); + } + }); + } + } + + // Trim whitespace from media type + return trim_copy(media_type); +} + bool can_compress_content_type(const std::string &content_type) { using udl::operator""_t; - auto tag = str2tag(content_type); + auto mime_type = extract_media_type(content_type); + auto tag = str2tag(mime_type); switch (tag) { case "image/svg+xml"_t: @@ -2257,7 +2622,7 @@ bool can_compress_content_type(const std::string &content_type) { case "text/event-stream"_t: return false; - default: return !content_type.rfind("text/", 0); + default: return !mime_type.rfind("text/", 0); } } @@ -2653,6 +3018,50 @@ bool read_headers(Stream &strm, Headers &headers) { return true; } +bool read_websocket_upgrade_response(Stream &strm, + const std::string &expected_accept, + std::string &selected_subprotocol) { + // Read status line + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + if (!line_reader.getline()) { return false; } + + // Check for "HTTP/1.1 101" + auto line = std::string(line_reader.ptr(), line_reader.size()); + if (line.find("HTTP/1.1 101") == std::string::npos) { return false; } + + // Parse headers using existing read_headers + Headers headers; + if (!read_headers(strm, headers)) { return false; } + + // Verify Upgrade: websocket (case-insensitive) + auto upgrade_it = headers.find("Upgrade"); + if (upgrade_it == headers.end()) { return false; } + auto upgrade_val = upgrade_it->second; + std::transform(upgrade_val.begin(), upgrade_val.end(), upgrade_val.begin(), + ::tolower); + if (upgrade_val != "websocket") { return false; } + + // Verify Connection header contains "Upgrade" (case-insensitive) + auto connection_it = headers.find("Connection"); + if (connection_it == headers.end()) { return false; } + auto connection_val = connection_it->second; + std::transform(connection_val.begin(), connection_val.end(), + connection_val.begin(), ::tolower); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Verify Sec-WebSocket-Accept header value + auto it = headers.find("Sec-WebSocket-Accept"); + if (it == headers.end() || it->second != expected_accept) { return false; } + + // Extract negotiated subprotocol + auto proto_it = headers.find("Sec-WebSocket-Protocol"); + if (proto_it != headers.end()) { selected_subprotocol = proto_it->second; } + + return true; +} + enum class ReadContentResult { Success, // Successfully read the content PayloadTooLarge, // The content exceeds the specified payload limit @@ -2768,7 +3177,8 @@ bool is_chunked_transfer_encoding(const Headers &headers) { template bool prepare_content_receiver(T &x, int &status, ContentReceiverWithProgress receiver, - bool decompress, U callback) { + bool decompress, size_t payload_max_length, + bool &exceed_payload_max_length, U callback) { if (decompress) { std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; @@ -2784,12 +3194,22 @@ bool prepare_content_receiver(T &x, int &status, if (decompressor) { if (decompressor->is_valid()) { + size_t decompressed_size = 0; ContentReceiverWithProgress out = [&](const char *buf, size_t n, size_t off, size_t len) { - return decompressor->decompress(buf, n, - [&](const char *buf2, size_t n2) { - return receiver(buf2, n2, off, len); - }); + return decompressor->decompress( + buf, n, [&](const char *buf2, size_t n2) { + // Guard against zip-bomb: check + // decompressed size against limit. + if (payload_max_length > 0 && + (decompressed_size >= payload_max_length || + n2 > payload_max_length - decompressed_size)) { + exceed_payload_max_length = true; + return false; + } + decompressed_size += n2; + return receiver(buf2, n2, off, len); + }); }; return callback(std::move(out)); } else { @@ -2810,11 +3230,14 @@ template bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, DownloadProgress progress, ContentReceiverWithProgress receiver, bool decompress) { + bool exceed_payload_max_length = false; return prepare_content_receiver( - x, status, std::move(receiver), decompress, - [&](const ContentReceiverWithProgress &out) { + x, status, std::move(receiver), decompress, payload_max_length, + exceed_payload_max_length, [&](const ContentReceiverWithProgress &out) { auto ret = true; - auto exceed_payload_max_length = false; + // Note: exceed_payload_max_length may also be set by the decompressor + // wrapper in prepare_content_receiver when the decompressed payload + // size exceeds the limit. if (is_chunked_transfer_encoding(x.headers)) { auto result = read_content_chunked(strm, x, payload_max_length, out); @@ -2941,10 +3364,10 @@ bool write_content_with_progress(Stream &strm, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -2956,6 +3379,11 @@ bool write_content_with_progress(Stream &strm, } } + if (offset < end_offset) { // exited due to is_shutting_down(), not completion + error = Error::Write; + return false; + } + error = Error::Success; return true; } @@ -2995,12 +3423,12 @@ write_content_without_length(Stream &strm, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -3008,7 +3436,8 @@ write_content_without_length(Stream &strm, return false; } } - return true; + return !data_available; // true only if done() was called, false if shutting + // down } template @@ -3044,7 +3473,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -3094,7 +3523,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -3106,6 +3535,11 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, } } + if (data_available) { // exited due to is_shutting_down(), not done() + error = Error::Write; + return false; + } + error = Error::Success; return true; } @@ -3219,12 +3653,11 @@ std::string normalize_query_string(const std::string &query) { bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { - auto boundary_keyword = "boundary="; - auto pos = content_type.find(boundary_keyword); - if (pos == std::string::npos) { return false; } - auto end = content_type.find(';', pos); - auto beg = pos + strlen(boundary_keyword); - boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + std::map params; + extract_media_type(content_type, ¶ms); + auto it = params.find("boundary"); + if (it == params.end()) { return false; } + boundary = it->second; return !boundary.empty(); } @@ -3392,11 +3825,7 @@ bool parse_accept_header(const std::string &s, } // Remove additional parameters from media type - auto param_pos = accept_entry.media_type.find(';'); - if (param_pos != std::string::npos) { - accept_entry.media_type = - trim_copy(accept_entry.media_type.substr(0, param_pos)); - } + accept_entry.media_type = extract_media_type(accept_entry.media_type); // Basic validation of media type format if (accept_entry.media_type.empty()) { @@ -3772,6 +4201,73 @@ serialize_multipart_formdata(const UploadFormDataItems &items, return body; } +size_t get_multipart_content_length(const UploadFormDataItems &items, + const std::string &boundary) { + size_t total = 0; + for (const auto &item : items) { + total += serialize_multipart_formdata_item_begin(item, boundary).size(); + total += item.content.size(); + total += serialize_multipart_formdata_item_end().size(); + } + total += serialize_multipart_formdata_finish(boundary).size(); + return total; +} + +struct MultipartSegment { + const char *data; + size_t size; +}; + +// NOTE: items must outlive the returned ContentProvider +// (safe for synchronous use inside Post/Put/Patch) +ContentProvider +make_multipart_content_provider(const UploadFormDataItems &items, + const std::string &boundary) { + // Own the per-item header strings and the finish string + std::vector owned; + owned.reserve(items.size() + 1); + for (const auto &item : items) + owned.push_back(serialize_multipart_formdata_item_begin(item, boundary)); + owned.push_back(serialize_multipart_formdata_finish(boundary)); + + // Flat segment list: [header, content, "\r\n"] * N + [finish] + std::vector segs; + segs.reserve(items.size() * 3 + 1); + static const char crlf[] = "\r\n"; + for (size_t i = 0; i < items.size(); i++) { + segs.push_back({owned[i].data(), owned[i].size()}); + segs.push_back({items[i].content.data(), items[i].content.size()}); + segs.push_back({crlf, 2}); + } + segs.push_back({owned.back().data(), owned.back().size()}); + + struct MultipartState { + std::vector owned; + std::vector segs; + }; + auto state = std::make_shared(); + state->owned = std::move(owned); + // `segs` holds raw pointers into owned strings; std::string move preserves + // the data pointer, so these pointers remain valid after the move above. + state->segs = std::move(segs); + + return [state](size_t offset, size_t length, DataSink &sink) -> bool { + size_t pos = 0; + for (const auto &seg : state->segs) { + // Loop invariant: pos <= offset (proven by advancing pos only when + // offset - pos >= seg.size, i.e., the segment doesn't contain offset) + if (seg.size > 0 && offset - pos < seg.size) { + size_t seg_offset = offset - pos; + size_t available = seg.size - seg_offset; + size_t to_write = (std::min)(available, length); + return sink.write(seg.data + seg_offset, to_write); + } + pos += seg.size; + } + return true; // past end (shouldn't be reached when content_length is exact) + }; +} + void coalesce_ranges(Ranges &ranges, size_t content_length) { if (ranges.size() <= 1) return; @@ -4020,15 +4516,6 @@ bool expect_content(const Request &req) { return false; } -bool has_crlf(const std::string &s) { - auto p = s.c_str(); - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; - } - return false; -} - #ifdef _WIN32 class WSInit { public: @@ -4148,6 +4635,52 @@ bool is_field_content(const std::string &s) { bool is_field_value(const std::string &s) { return is_field_content(s); } } // namespace fields + +bool perform_websocket_handshake(Stream &strm, const std::string &host, + int port, const std::string &path, + const Headers &headers, + std::string &selected_subprotocol) { + // Validate path and host + if (!fields::is_field_value(path) || !fields::is_field_value(host)) { + return false; + } + + // Validate user-provided headers + for (const auto &h : headers) { + if (!fields::is_field_name(h.first) || !fields::is_field_value(h.second)) { + return false; + } + } + + // Generate random Sec-WebSocket-Key + thread_local std::mt19937 rng(std::random_device{}()); + std::string key_bytes(16, '\0'); + for (size_t i = 0; i < 16; i += 4) { + auto r = rng(); + std::memcpy(&key_bytes[i], &r, (std::min)(size_t(4), size_t(16 - i))); + } + auto client_key = base64_encode(key_bytes); + + // Build upgrade request + std::string req_str = "GET " + path + " HTTP/1.1\r\n"; + req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n"; + req_str += "Upgrade: websocket\r\n"; + req_str += "Connection: Upgrade\r\n"; + req_str += "Sec-WebSocket-Key: " + client_key + "\r\n"; + req_str += "Sec-WebSocket-Version: 13\r\n"; + for (const auto &h : headers) { + req_str += h.first + ": " + h.second + "\r\n"; + } + req_str += "\r\n"; + + if (strm.write(req_str.data(), req_str.size()) < 0) { return false; } + + // Verify 101 response and Sec-WebSocket-Accept header + auto expected_accept = websocket_accept_key(client_key); + return read_websocket_upgrade_response(strm, expected_accept, + selected_subprotocol); +} + } // namespace detail /* @@ -4170,12 +4703,14 @@ public: bool is_readable() const override; bool wait_readable() const override; bool wait_writable() const override; + bool is_peer_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; private: socket_t sock_; @@ -4268,6 +4803,39 @@ std::string SHA_512(const std::string &s) { #endif return hash_to_hex(hash); } +#elif defined(CPPHTTPLIB_WOLFSSL_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} // namespace + +std::string MD5(const std::string &s) { + unsigned char hash[WC_MD5_DIGEST_SIZE]; + wc_Md5Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[WC_SHA256_DIGEST_SIZE]; + wc_Sha256Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[WC_SHA512_DIGEST_SIZE]; + wc_Sha512Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} #endif bool is_ip_address(const std::string &host) { @@ -4510,6 +5078,53 @@ bool verify_cert_with_windows_schannel( } #endif // _WIN32 +bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx, + tls::session_t &session, socket_t sock, + bool server_certificate_verification, + const std::string &ca_cert_file_path, + tls::ca_store_t ca_cert_store, + time_t timeout_sec, time_t timeout_usec) { + using namespace tls; + + ctx = create_client_context(); + if (!ctx) { return false; } + + if (server_certificate_verification) { + if (!ca_cert_file_path.empty()) { + load_ca_file(ctx, ca_cert_file_path.c_str()); + } + if (ca_cert_store) { set_ca_store(ctx, ca_cert_store); } + load_system_certs(ctx); + } + + bool is_ip = is_ip_address(host); + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT + if (is_ip && server_certificate_verification) { + set_verify_client(ctx, false); + } else { + set_verify_client(ctx, server_certificate_verification); + } +#endif + + session = create_session(ctx, sock); + if (!session) { return false; } + + // RFC 6066: SNI must not be set for IP addresses + if (!is_ip) { set_sni(session, host.c_str()); } + if (server_certificate_verification) { set_hostname(session, host.c_str()); } + + if (!connect_nonblocking(session, sock, timeout_sec, timeout_usec, nullptr)) { + return false; + } + + if (server_certificate_verification) { + if (get_verify_result(session) != 0) { return false; } + } + + return true; +} + } // namespace detail #endif // CPPHTTPLIB_SSL_ENABLED @@ -5040,7 +5655,7 @@ size_t Request::get_param_value_count(const std::string &key) const { bool Request::is_multipart_form_data() const { const auto &content_type = get_header_value("Content-Type"); - return !content_type.rfind("multipart/form-data", 0); + return detail::extract_media_type(content_type) == "multipart/form-data"; } // Multipart FormData implementation @@ -5327,22 +5942,37 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } // ThreadPool implementation -ThreadPool::ThreadPool(size_t n, size_t mqr) - : shutdown_(false), max_queued_requests_(mqr) { - threads_.reserve(n); - while (n) { - threads_.emplace_back(worker(*this)); - n--; +ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr) + : base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0), + shutdown_(false) { +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (max_n != 0 && max_n < n) { + std::string msg = "max_threads must be >= base_threads"; + throw std::invalid_argument(msg); + } +#endif + max_thread_count_ = max_n == 0 ? n : max_n; + threads_.reserve(base_thread_count_); + for (size_t i = 0; i < base_thread_count_; i++) { + threads_.emplace_back(std::thread([this]() { worker(false); })); } } bool ThreadPool::enqueue(std::function fn) { { std::unique_lock lock(mutex_); + if (shutdown_) { return false; } if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { return false; } jobs_.push_back(std::move(fn)); + + // Spawn a dynamic thread if no idle threads and under max + if (idle_thread_count_ == 0 && + threads_.size() + dynamic_threads_.size() < max_thread_count_) { + cleanup_finished_threads(); + dynamic_threads_.emplace_back(std::thread([this]() { worker(true); })); + } } cond_.notify_one(); @@ -5350,7 +5980,6 @@ bool ThreadPool::enqueue(std::function fn) { } void ThreadPool::shutdown() { - // Stop all worker threads... { std::unique_lock lock(mutex_); shutdown_ = true; @@ -5358,31 +5987,84 @@ void ThreadPool::shutdown() { cond_.notify_all(); - // Join... for (auto &t : threads_) { - t.join(); + if (t.joinable()) { t.join(); } + } + + // Move dynamic_threads_ to a local list under the lock to avoid racing + // with worker threads that call move_to_finished() concurrently. + std::list remaining_dynamic; + { + std::unique_lock lock(mutex_); + remaining_dynamic = std::move(dynamic_threads_); + } + for (auto &t : remaining_dynamic) { + if (t.joinable()) { t.join(); } + } + + std::unique_lock lock(mutex_); + cleanup_finished_threads(); +} + +void ThreadPool::move_to_finished(std::thread::id id) { + // Must be called with mutex_ held + for (auto it = dynamic_threads_.begin(); it != dynamic_threads_.end(); ++it) { + if (it->get_id() == id) { + finished_threads_.push_back(std::move(*it)); + dynamic_threads_.erase(it); + return; + } } } -ThreadPool::worker::worker(ThreadPool &pool) : pool_(pool) {} +void ThreadPool::cleanup_finished_threads() { + // Must be called with mutex_ held + for (auto &t : finished_threads_) { + if (t.joinable()) { t.join(); } + } + finished_threads_.clear(); +} -void ThreadPool::worker::operator()() { +void ThreadPool::worker(bool is_dynamic) { for (;;) { std::function fn; { - std::unique_lock lock(pool_.mutex_); + std::unique_lock lock(mutex_); + idle_thread_count_++; - pool_.cond_.wait(lock, - [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + if (is_dynamic) { + auto has_work = cond_.wait_for( + lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT), + [&] { return !jobs_.empty() || shutdown_; }); + if (!has_work) { + // Timed out with no work - exit this dynamic thread + idle_thread_count_--; + move_to_finished(std::this_thread::get_id()); + break; + } + } else { + cond_.wait(lock, [&] { return !jobs_.empty() || shutdown_; }); + } - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + idle_thread_count_--; - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); + if (shutdown_ && jobs_.empty()) { break; } + + fn = std::move(jobs_.front()); + jobs_.pop_front(); } assert(true == static_cast(fn)); fn(); + + // Dynamic thread: exit if queue is empty after task completion + if (is_dynamic) { + std::unique_lock lock(mutex_); + if (jobs_.empty()) { + move_to_finished(std::this_thread::get_id()); + break; + } + } } #if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ @@ -5445,8 +6127,11 @@ bool SocketStream::wait_readable() const { } bool SocketStream::wait_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_); + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; +} + +bool SocketStream::is_peer_alive() const { + return detail::is_socket_alive(sock_); } ssize_t SocketStream::read(char *ptr, size_t size) { @@ -5540,6 +6225,11 @@ time_t SocketStream::duration() const { .count(); } +void SocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + // Buffer stream implementation bool BufferStream::is_readable() const { return true; } @@ -5772,7 +6462,11 @@ bool SSLSocketStream::wait_readable() const { bool SSLSocketStream::wait_writable() const { return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_) && !tls::is_peer_closed(session_, sock_); + !tls::is_peer_closed(session_, sock_); +} + +bool SSLSocketStream::is_peer_alive() const { + return !tls::is_peer_closed(session_, sock_); } ssize_t SSLSocketStream::read(char *ptr, size_t size) { @@ -5865,6 +6559,11 @@ time_t SSLSocketStream::duration() const { .count(); } +void SSLSocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + } // namespace detail #endif // CPPHTTPLIB_SSL_ENABLED @@ -5874,8 +6573,10 @@ time_t SSLSocketStream::duration() const { // HTTP server implementation Server::Server() - : new_task_queue( - [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { + : new_task_queue([] { + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT, + CPPHTTPLIB_THREAD_POOL_MAX_COUNT); + }) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif @@ -5950,6 +6651,21 @@ Server &Server::Options(const std::string &pattern, Handler handler) { return *this; } +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler) { + websocket_handlers_.push_back( + {make_matcher(pattern), std::move(handler), nullptr}); + return *this; +} + +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector) { + websocket_handlers_.push_back({make_matcher(pattern), std::move(handler), + std::move(sub_protocol_selector)}); + return *this; +} + bool Server::set_base_dir(const std::string &dir, const std::string &mount_point) { return set_mount_point(mount_point, dir); @@ -6274,35 +6990,33 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (post_routing_handler_) { post_routing_handler_(req, res); } // Response line and headers - { - detail::BufferStream bstrm; - if (!detail::write_response_line(bstrm, res.status)) { return false; } - if (header_writer_(bstrm, res.headers) <= 0) { return false; } + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (header_writer_(bstrm, res.headers) <= 0) { return false; } - // Flush buffer - auto &data = bstrm.get_buffer(); - detail::write_data(strm, data.data(), data.size()); + // Combine small body with headers to reduce write syscalls + if (req.method != "HEAD" && !res.body.empty() && !res.content_provider_) { + bstrm.write(res.body.data(), res.body.size()); } - // Body + // Log before writing to avoid race condition with client-side code that + // accesses logger-captured data immediately after receiving the response. + output_log(req, res); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { return false; } + + // Streaming body auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!detail::write_data(strm, res.body.data(), res.body.size())) { - ret = false; - } - } else if (res.content_provider_) { - if (write_content_with_provider(strm, req, res, boundary, content_type)) { - res.content_provider_success_ = true; - } else { - ret = false; - } + if (req.method != "HEAD" && res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; } } - // Log - output_log(req, res); - return ret; } @@ -6423,7 +7137,8 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { return true; })) { const auto &content_type = req.get_header_value("Content-Type"); - if (!content_type.find("application/x-www-form-urlencoded")) { + if (detail::extract_media_type(content_type) == + "application/x-www-form-urlencoded") { if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? output_error_log(Error::ExceedMaxPayloadSize, &req); @@ -6810,45 +7525,63 @@ bool Server::routing(Request &req, Response &res, Stream &strm) { if (detail::expect_content(req)) { // Content reader handler { + // Track whether the ContentReader was aborted due to the decompressed + // payload exceeding `payload_max_length_`. + // The user handler runs after the lambda returns, so we must restore the + // 413 status if the handler overwrites it. + bool content_reader_payload_too_large = false; + ContentReader reader( [&](ContentReceiver receiver) { auto result = read_content_with_content_receiver( strm, req, res, std::move(receiver), nullptr, nullptr); - if (!result) { output_error_log(Error::Read, &req); } + if (!result) { + output_error_log(Error::Read, &req); + if (res.status == StatusCode::PayloadTooLarge_413) { + content_reader_payload_too_large = true; + } + } return result; }, [&](FormDataHeader header, ContentReceiver receiver) { auto result = read_content_with_content_receiver( strm, req, res, nullptr, std::move(header), std::move(receiver)); - if (!result) { output_error_log(Error::Read, &req); } + if (!result) { + output_error_log(Error::Read, &req); + if (res.status == StatusCode::PayloadTooLarge_413) { + content_reader_payload_too_large = true; + } + } return result; }); + bool dispatched = false; if (req.method == "POST") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - post_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), post_handlers_for_content_reader_); } else if (req.method == "PUT") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - put_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), put_handlers_for_content_reader_); } else if (req.method == "PATCH") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - patch_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), patch_handlers_for_content_reader_); } else if (req.method == "DELETE") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - delete_handlers_for_content_reader_)) { - return true; + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), delete_handlers_for_content_reader_); + } + + if (dispatched) { + if (content_reader_payload_too_large) { + // Enforce the limit: override any status the handler may have set + // and return false so the error path sends a plain 413 response. + res.status = StatusCode::PayloadTooLarge_413; + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + return false; } + return true; } } @@ -7072,7 +7805,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request) { + const std::function &setup_request, + bool *websocket_upgraded) { std::array buf{}; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); @@ -7175,6 +7909,77 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return !detail::is_socket_alive(sock); }; + // WebSocket upgrade + // Check pre_routing_handler_ before upgrading so that authentication + // and other middleware can reject the request with an HTTP response + // (e.g., 401) before the protocol switches. + if (detail::is_websocket_upgrade(req)) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + if (res.status == -1) { res.status = StatusCode::OK_200; } + return write_response(strm, close_connection, req, res); + } + // Find matching WebSocket handler + for (const auto &entry : websocket_handlers_) { + if (entry.matcher->match(req)) { + // Compute accept key + auto client_key = req.get_header_value("Sec-WebSocket-Key"); + auto accept_key = detail::websocket_accept_key(client_key); + + // Negotiate subprotocol + std::string selected_subprotocol; + if (entry.sub_protocol_selector) { + auto protocol_header = req.get_header_value("Sec-WebSocket-Protocol"); + if (!protocol_header.empty()) { + std::vector protocols; + std::istringstream iss(protocol_header); + std::string token; + while (std::getline(iss, token, ',')) { + // Trim whitespace + auto start = token.find_first_not_of(' '); + auto end = token.find_last_not_of(' '); + if (start != std::string::npos) { + protocols.push_back(token.substr(start, end - start + 1)); + } + } + selected_subprotocol = entry.sub_protocol_selector(protocols); + } + } + + // Send 101 Switching Protocols + std::string handshake_response = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + + accept_key + "\r\n"; + if (!selected_subprotocol.empty()) { + if (!detail::fields::is_field_value(selected_subprotocol)) { + return false; + } + handshake_response += + "Sec-WebSocket-Protocol: " + selected_subprotocol + "\r\n"; + } + handshake_response += "\r\n"; + if (strm.write(handshake_response.data(), handshake_response.size()) < + 0) { + return false; + } + + connection_closed = true; + if (websocket_upgraded) { *websocket_upgraded = true; } + + { + // Use WebSocket-specific read timeout instead of HTTP timeout + strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0); + ws::WebSocket ws(strm, req, true); + entry.handler(req, ws); + } + return true; + } + } + // No matching handler - fall through to 404 + } + // Routing auto routed = false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -7189,16 +7994,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, routed = true; } else { res.status = StatusCode::InternalServerError_500; - std::string val; - auto s = e.what(); - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case '\r': val += "\\r"; break; - case '\n': val += "\\n"; break; - default: val += s[i]; break; - } - } - res.set_header("EXCEPTION_WHAT", val); } } catch (...) { if (exception_handler_) { @@ -7207,7 +8002,6 @@ Server::process_request(Stream &strm, const std::string &remote_addr, routed = true; } else { res.status = StatusCode::InternalServerError_500; - res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } } #endif @@ -7271,6 +8065,7 @@ bool Server::process_and_close_socket(socket_t sock) { int local_port = 0; detail::get_local_ip_and_port(sock, local_addr, local_port); + bool websocket_upgraded = false; auto ret = detail::process_server_socket( svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, @@ -7278,7 +8073,7 @@ bool Server::process_and_close_socket(socket_t sock) { [&](Stream &strm, bool close_connection, bool &connection_closed) { return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, connection_closed, - nullptr); + nullptr, &websocket_upgraded); }); detail::shutdown_socket(sock); @@ -9019,8 +9814,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -9033,8 +9830,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -9212,8 +10011,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -9226,8 +10027,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -9407,8 +10210,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -9421,8 +10226,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -10579,9 +11386,9 @@ bool SSLServer::process_and_close_socket(socket_t sock) { // Use scope_exit to ensure cleanup on all paths (including exceptions) bool handshake_done = false; bool ret = false; + bool websocket_upgraded = false; auto cleanup = detail::scope_exit([&] { - // Shutdown gracefully if handshake succeeded and processing was successful - if (handshake_done) { shutdown(session, ret); } + if (handshake_done) { shutdown(session, !websocket_upgraded && ret); } free_session(session); detail::shutdown_socket(sock); detail::close_socket(sock); @@ -10621,9 +11428,10 @@ bool SSLServer::process_and_close_socket(socket_t sock) { read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, write_timeout_usec_, [&](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, remote_addr, remote_port, local_addr, - local_port, close_connection, connection_closed, - [&](Request &req) { req.ssl = session; }); + return process_request( + strm, remote_addr, remote_port, local_addr, local_port, + close_connection, connection_closed, + [&](Request &req) { req.ssl = session; }, &websocket_upgraded); }); return ret; @@ -10874,8 +11682,7 @@ void SSLClient::set_session_verifier( session_verifier_ = std::move(verifier); } -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE void SSLClient::enable_windows_certificate_verification(bool enabled) { enable_windows_cert_verification_ = enabled; } @@ -10929,11 +11736,11 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) { bool is_ip = detail::is_ip_address(host_); -#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT - // MbedTLS needs explicit verification mode (OpenSSL uses SSL_VERIFY_NONE - // by default and performs all verification post-handshake). +#if defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) + // MbedTLS/wolfSSL need explicit verification mode (OpenSSL uses + // SSL_VERIFY_NONE by default and performs all verification post-handshake). // For IP addresses with verification enabled, use OPTIONAL mode since - // MbedTLS requires hostname for VERIFY_REQUIRED. + // these backends require hostname for strict verification. if (is_ip && server_certificate_verification_) { set_verify_client(ctx_, false); } else { @@ -11033,8 +11840,7 @@ bool SSLClient::initialize_ssl(Socket &socket, Error &error) { } } -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE // Additional Windows Schannel verification. // This provides real-time certificate validation with Windows Update // integration, working with both OpenSSL and MbedTLS backends. @@ -11080,8 +11886,7 @@ void Client::enable_server_hostname_verification(bool enabled) { cli_->enable_server_hostname_verification(enabled); } -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE void Client::enable_windows_certificate_verification(bool enabled) { if (is_ssl_) { static_cast(*cli_).enable_windows_certificate_verification( @@ -11154,6 +11959,107 @@ VerifyCallback &get_mbedtls_verify_callback() { return callback; } +// Check if a string is an IPv4 address +bool is_ipv4_address(const std::string &str) { + int dots = 0; + for (char c : str) { + if (c == '.') { + dots++; + } else if (!isdigit(static_cast(c))) { + return false; + } + } + return dots == 3; +} + +// Parse IPv4 address string to bytes +bool parse_ipv4(const std::string &str, unsigned char *out) { + int parts[4]; + if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], + &parts[3]) != 4) { + return false; + } + for (int i = 0; i < 4; i++) { + if (parts[i] < 0 || parts[i] > 255) return false; + out[i] = static_cast(parts[i]); + } + return true; +} + +#ifdef _WIN32 +// Enumerate Windows system certificates and call callback with DER data +template +bool enumerate_windows_system_certs(Callback cb) { + bool loaded = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); + if (hStore) { + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + if (cb(pContext->pbCertEncoded, pContext->cbCertEncoded)) { + loaded = true; + } + } + CertCloseStore(hStore, 0); + } + } + return loaded; +} +#endif + +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +// Enumerate macOS Keychain certificates and call callback with DER data +template +bool enumerate_macos_keychain_certs(Callback cb) { + bool loaded = false; + CFArrayRef certs = nullptr; + OSStatus status = SecTrustCopyAnchorCertificates(&certs); + if (status == errSecSuccess && certs) { + CFIndex count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + SecCertificateRef cert = + (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + CFDataRef data = SecCertificateCopyData(cert); + if (data) { + if (cb(CFDataGetBytePtr(data), + static_cast(CFDataGetLength(data)))) { + loaded = true; + } + CFRelease(data); + } + } + CFRelease(certs); + } + return loaded; +} +#endif + +#if !defined(_WIN32) && !(defined(__APPLE__) && \ + defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)) +// Common CA certificate file paths on Linux/Unix +const char **system_ca_paths() { + static const char *paths[] = { + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu + "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/ssl/cert.pem", // Alpine, FreeBSD + nullptr}; + return paths; +} + +// Common CA certificate directory paths on Linux/Unix +const char **system_ca_dirs() { + static const char *dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu + "/etc/pki/tls/certs", // RHEL/CentOS + "/usr/share/ca-certificates", // Other + nullptr}; + return dirs; +} +#endif + } // namespace impl bool set_client_ca_file(ctx_t ctx, const char *ca_file, @@ -12730,33 +13636,6 @@ int mbedtls_sni_callback(void *p_ctx, mbedtls_ssl_context *ssl, int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, int cert_depth, uint32_t *flags); -// Check if a string is an IPv4 address -bool is_ipv4_address(const std::string &str) { - int dots = 0; - for (char c : str) { - if (c == '.') { - dots++; - } else if (!isdigit(static_cast(c))) { - return false; - } - } - return dots == 3; -} - -// Parse IPv4 address string to bytes -bool parse_ipv4(const std::string &str, unsigned char *out) { - int parts[4]; - if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], - &parts[3]) != 4) { - return false; - } - for (int i = 0; i < 4; i++) { - if (parts[i] < 0 || parts[i] > 255) return false; - out[i] = static_cast(parts[i]); - } - return true; -} - // MbedTLS verify callback wrapper int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, int cert_depth, uint32_t *flags) { @@ -12971,68 +13850,26 @@ bool load_system_certs(ctx_t ctx) { bool loaded = false; #ifdef _WIN32 - // Load from Windows certificate store (ROOT and CA) - static const wchar_t *store_names[] = {L"ROOT", L"CA"}; - for (auto store_name : store_names) { - HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); - if (hStore) { - PCCERT_CONTEXT pContext = nullptr; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - int ret = mbedtls_x509_crt_parse_der( - &mctx->ca_chain, pContext->pbCertEncoded, pContext->cbCertEncoded); - if (ret == 0) { loaded = true; } - } - CertCloseStore(hStore, 0); - } - } + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); #elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) - // Load from macOS Keychain - CFArrayRef certs = nullptr; - OSStatus status = SecTrustCopyAnchorCertificates(&certs); - if (status == errSecSuccess && certs) { - CFIndex count = CFArrayGetCount(certs); - for (CFIndex i = 0; i < count; i++) { - SecCertificateRef cert = - (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); - CFDataRef data = SecCertificateCopyData(cert); - if (data) { - int ret = mbedtls_x509_crt_parse_der( - &mctx->ca_chain, CFDataGetBytePtr(data), - static_cast(CFDataGetLength(data))); - if (ret == 0) { loaded = true; } - CFRelease(data); - } - } - CFRelease(certs); - } + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); #else - // Try common CA certificate locations on Linux/Unix - static const char *ca_paths[] = { - "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu - "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS - "/etc/ssl/ca-bundle.pem", // OpenSUSE - "/etc/pki/tls/cacert.pem", // OpenELEC - "/etc/ssl/cert.pem", // Alpine, FreeBSD - nullptr}; - - for (const char **path = ca_paths; *path; ++path) { - int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path); - if (ret >= 0) { + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path) >= 0) { loaded = true; break; } } - // Also try the CA directory if (!loaded) { - static const char *ca_dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu - "/etc/pki/tls/certs", // RHEL/CentOS - "/usr/share/ca-certificates", nullptr}; - - for (const char **dir = ca_dirs; *dir; ++dir) { - int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir); - if (ret >= 0) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir) >= 0) { loaded = true; break; } @@ -13083,6 +13920,18 @@ bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, return false; } + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); if (ret != 0) { impl::mbedtls_last_error() = ret; @@ -13116,6 +13965,18 @@ bool set_client_cert_file(ctx_t ctx, const char *cert_path, return false; } + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); if (ret != 0) { impl::mbedtls_last_error() = ret; @@ -13877,4 +14738,1477 @@ std::string verify_error_string(long error_code) { #endif // CPPHTTPLIB_MBEDTLS_SUPPORT +/* + * Group 10: TLS abstraction layer - wolfSSL backend + */ + +/* + * wolfSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { + +namespace impl { + +// wolfSSL session wrapper +struct WolfSSLSession { + WOLFSSL *ssl = nullptr; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + WolfSSLSession() = default; + + ~WolfSSLSession() { + if (ssl) { wolfSSL_free(ssl); } + } + + WolfSSLSession(const WolfSSLSession &) = delete; + WolfSSLSession &operator=(const WolfSSLSession &) = delete; +}; + +// Thread-local error code accessor for wolfSSL +uint64_t &wolfssl_last_error() { + static thread_local uint64_t err = 0; + return err; +} + +// Helper to map wolfSSL error to ErrorCode. +// ssl_error is the value from wolfSSL_get_error(). +// raw_ret is the raw return value from the wolfSSL call (for low-level error). +ErrorCode map_wolfssl_error(WOLFSSL *ssl, int ssl_error, + int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + default: + if (ssl) { + // wolfSSL stores the low-level error code as a negative value. + // DOMAIN_NAME_MISMATCH (-322) indicates hostname verification failure. + int low_err = ssl_error; // wolfSSL_get_error returns the low-level code + if (low_err == DOMAIN_NAME_MISMATCH) { + return ErrorCode::HostnameMismatch; + } + // Check verify result to distinguish cert verification from generic SSL + // errors. + long vr = wolfSSL_get_verify_result(ssl); + if (vr != 0) { return ErrorCode::CertVerifyFailed; } + } + return ErrorCode::Fatal; + } +} + +// WolfSSLContext constructor/destructor implementations +WolfSSLContext::WolfSSLContext() { wolfSSL_Init(); } + +WolfSSLContext::~WolfSSLContext() { + if (ctx) { wolfSSL_CTX_free(ctx); } +} + +// Thread-local storage for SNI captured during handshake +std::string &wolfssl_pending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for wolfSSL server to capture client's SNI hostname +int wolfssl_sni_callback(WOLFSSL *ssl, int *ret, void *exArg) { + (void)ret; + (void)exArg; + + void *name_data = nullptr; + unsigned short name_len = + wolfSSL_SNI_GetRequest(ssl, WOLFSSL_SNI_HOST_NAME, &name_data); + + if (name_data && name_len > 0) { + wolfssl_pending_sni().assign(static_cast(name_data), + name_len); + } else { + wolfssl_pending_sni().clear(); + } + return 0; // Continue regardless +} + +// wolfSSL verify callback wrapper +int wolfssl_verify_callback(int preverify_ok, + WOLFSSL_X509_STORE_CTX *x509_ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + WOLFSSL_X509 *cert = wolfSSL_X509_STORE_CTX_get_current_cert(x509_ctx); + int depth = wolfSSL_X509_STORE_CTX_get_error_depth(x509_ctx); + int err = wolfSSL_X509_STORE_CTX_get_error(x509_ctx); + + // Get the WOLFSSL object from the X509_STORE_CTX + WOLFSSL *ssl = static_cast(wolfSSL_X509_STORE_CTX_get_ex_data( + x509_ctx, wolfSSL_get_ex_data_X509_STORE_CTX_idx())); + + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = static_cast(err); + + if (err != 0) { + verify_ctx.error_string = wolfSSL_X509_verify_cert_error_string(err); + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + return accepted ? 1 : 0; +} + +void set_wolfssl_password_cb(WOLFSSL_CTX *ctx, const char *password) { + wolfSSL_CTX_set_default_passwd_cb_userdata(ctx, const_cast(password)); + wolfSSL_CTX_set_default_passwd_cb( + ctx, [](char *buf, int size, int /*rwflag*/, void *userdata) -> int { + auto *pwd = static_cast(userdata); + if (!pwd) return 0; + auto len = static_cast(strlen(pwd)); + if (len > size) len = size; + memcpy(buf, pwd, static_cast(len)); + return len; + }); +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + WOLFSSL_METHOD *method = wolfTLSv1_2_client_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: verify peer certificate + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER, nullptr); + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + WOLFSSL_METHOD *method = wolfTLSv1_2_server_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: don't verify client + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_NONE, nullptr); + + // Enable SNI on server + wolfSSL_CTX_SNI_SetOptions(ctx->ctx, WOLFSSL_SNI_HOST_NAME, + WOLFSSL_SNI_CONTINUE_ON_MISMATCH); + wolfSSL_CTX_set_servername_callback(ctx->ctx, impl::wolfssl_sni_callback); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + + int min_ver = WOLFSSL_TLSV1_2; + if (version >= Version::TLS1_3) { min_ver = WOLFSSL_TLSV1_3; } + + return wolfSSL_CTX_SetMinVersion(wctx->ctx, min_ver) == WOLFSSL_SUCCESS; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + wctx->ca_pem_data_.append(pem, len); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, file_path, nullptr); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, dir_path); + // wolfSSL may fail if the directory doesn't contain properly hashed certs. + // Unlike OpenSSL which lazily loads certs from directories, wolfSSL scans + // immediately. Return true even on failure since the CA file may have + // already been loaded, matching OpenSSL's lenient behavior. + (void)ret; + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#else + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, *path, nullptr) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + + if (!loaded) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, *dir) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + } +#endif + + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert), + static_cast(strlen(cert)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key), + static_cast(strlen(key)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate file + int ret = + wolfSSL_CTX_use_certificate_file(wctx->ctx, cert_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key file + ret = wolfSSL_CTX_use_PrivateKey_file(wctx->ctx, key_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto wctx = static_cast(ctx); + wctx->verify_client = require; + if (require) { + wolfSSL_CTX_set_verify( + wctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + wctx->has_verify_callback ? impl::wolfssl_verify_callback : nullptr); + } else { + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_NONE, nullptr); + } + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto wctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::WolfSSLSession(); + if (!session) { return nullptr; } + + session->sock = sock; + session->ssl = wolfSSL_new(wctx->ctx); + if (!session->ssl) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + delete session; + return nullptr; + } + + wolfSSL_set_fd(session->ssl, static_cast(sock)); + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto wsession = static_cast(session); + + int ret = wolfSSL_UseSNI(wsession->ssl, WOLFSSL_SNI_HOST_NAME, hostname, + static_cast(strlen(hostname))); + if (ret != WOLFSSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Also set hostname for verification + wolfSSL_check_domain_name(wsession->ssl, hostname); + + wsession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In wolfSSL, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_connect(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +TlsError accept(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_accept(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_connect(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_accept(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_read(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_write(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + // wolfSSL_write returns 0 when the peer has sent a close_notify. + // Treat this as an error (return -1) so callers don't spin in a + // write loop adding zero to the offset. + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return -1; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto wsession = + static_cast(const_cast(session)); + return wolfSSL_pending(wsession->ssl); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto wsession = static_cast(session); + + if (graceful) { + int ret; + int attempts = 0; + while ((ret = wolfSSL_shutdown(wsession->ssl)) != SSL_SUCCESS && + attempts < 3) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error != SSL_ERROR_WANT_READ && + ssl_error != SSL_ERROR_WANT_WRITE) { + break; + } + attempts++; + } + } else { + wolfSSL_shutdown(wsession->ssl); + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto wsession = static_cast(session); + + // Check if there's already decrypted data available + if (wolfSSL_pending(wsession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Peek 1 byte to check connection status without consuming data + unsigned char buf; + int ret = wolfSSL_peek(wsession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0) { return false; } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { return false; } + + return ssl_error == SSL_ERROR_ZERO_RETURN || ssl_error == SSL_ERROR_SYSCALL || + ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto wsession = + static_cast(const_cast(session)); + + WOLFSSL_X509 *cert = wolfSSL_get_peer_certificate(wsession->ssl); + return static_cast(cert); +} + +void free_cert(cert_t cert) { + if (cert) { wolfSSL_X509_free(static_cast(cert)); } +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto x509 = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + + if (san_names) { + int san_count = wolfSSL_sk_num(san_names); + for (int i = 0; i < san_count; i++) { + auto *names = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!names) continue; + + if (!is_ip && names->type == WOLFSSL_GEN_DNS) { + // DNS name + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, names->d.dNSName); + if (dns_name && dns_len > 0) { + std::string san_name(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + if (detail::match_hostname(san_name, host_str)) { + wolfSSL_sk_free(san_names); + return true; + } + } + } else if (is_ip && names->type == WOLFSSL_GEN_IPADD) { + // IP address + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(names->d.iPAddress); + int ip_len = wolfSSL_ASN1_STRING_length(names->d.iPAddress); + if (ip_data && ip_len == 4 && memcmp(ip_data, ip_bytes, 4) == 0) { + wolfSSL_sk_free(san_names); + return true; + } + } + } + wolfSSL_sk_free(san_names); + } + + // Fallback: Check Common Name (CN) in subject + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len > 0) { + std::string cn_str(cn, static_cast(cn_len)); + if (detail::match_hostname(cn_str, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(DOMAIN_NAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto wsession = + static_cast(const_cast(session)); + long result = wolfSSL_get_verify_result(wsession->ssl); + return result; +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (!subject) return ""; + + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len <= 0) return ""; + return std::string(cn, static_cast(cn_len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *issuer = wolfSSL_X509_get_issuer_name(x509); + if (!issuer) return ""; + + char *name_str = wolfSSL_X509_NAME_oneline(issuer, nullptr, 0); + if (!name_str) return ""; + + std::string result(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + return result; +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!san_names) return true; // No SANs is not an error + + int count = wolfSSL_sk_num(san_names); + for (int i = 0; i < count; i++) { + auto *name = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!name) continue; + + SanEntry entry; + switch (name->type) { + case WOLFSSL_GEN_DNS: { + entry.type = SanType::DNS; + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, name->d.dNSName); + if (dns_name && dns_len > 0) { + entry.value = std::string(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + } + break; + } + case WOLFSSL_GEN_IPADD: { + entry.type = SanType::IP; + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(name->d.iPAddress); + int ip_len = wolfSSL_ASN1_STRING_length(name->d.iPAddress); + if (ip_data && ip_len == 4) { + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", ip_data[0], ip_data[1], + ip_data[2], ip_data[3]); + entry.value = buf; + } else if (ip_data && ip_len == 16) { + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + ip_data[0], ip_data[1], ip_data[2], ip_data[3], ip_data[4], + ip_data[5], ip_data[6], ip_data[7], ip_data[8], ip_data[9], + ip_data[10], ip_data[11], ip_data[12], ip_data[13], + ip_data[14], ip_data[15]); + entry.value = buf; + } + break; + } + case WOLFSSL_GEN_EMAIL: + entry.type = SanType::EMAIL; + { + unsigned char *email = nullptr; + int email_len = wolfSSL_ASN1_STRING_to_UTF8(&email, name->d.rfc822Name); + if (email && email_len > 0) { + entry.value = std::string(reinterpret_cast(email), + static_cast(email_len)); + XFREE(email, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + break; + case WOLFSSL_GEN_URI: + entry.type = SanType::URI; + { + unsigned char *uri = nullptr; + int uri_len = wolfSSL_ASN1_STRING_to_UTF8( + &uri, name->d.uniformResourceIdentifier); + if (uri && uri_len > 0) { + entry.value = std::string(reinterpret_cast(uri), + static_cast(uri_len)); + XFREE(uri, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + wolfSSL_sk_free(san_names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + const WOLFSSL_ASN1_TIME *nb = wolfSSL_X509_get_notBefore(x509); + const WOLFSSL_ASN1_TIME *na = wolfSSL_X509_get_notAfter(x509); + + if (!nb || !na) return false; + + // wolfSSL_ASN1_TIME_to_tm is available + struct tm tm_nb = {}, tm_na = {}; + if (wolfSSL_ASN1_TIME_to_tm(nb, &tm_nb) != WOLFSSL_SUCCESS) return false; + if (wolfSSL_ASN1_TIME_to_tm(na, &tm_na) != WOLFSSL_SUCCESS) return false; + +#ifdef _WIN32 + not_before = _mkgmtime(&tm_nb); + not_after = _mkgmtime(&tm_na); +#else + not_before = timegm(&tm_nb); + not_after = timegm(&tm_na); +#endif + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_ASN1_INTEGER *serial_asn1 = wolfSSL_X509_get_serialNumber(x509); + if (!serial_asn1) return ""; + + // Get the serial number data + int len = serial_asn1->length; + unsigned char *data = serial_asn1->data; + if (!data || len <= 0) return ""; + + std::string result; + result.reserve(static_cast(len) * 2); + for (int i = 0; i < len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", data[i]); + result += hex; + } + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + + int der_len = 0; + const unsigned char *der_data = wolfSSL_X509_get_der(x509, &der_len); + if (!der_data || der_len <= 0) return false; + + der.assign(der_data, der_data + der_len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto wsession = static_cast(session); + + // For server: return SNI received from client during handshake + if (!wsession->sni_hostname.empty()) { + return wsession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!wsession->hostname.empty()) { return wsession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + return static_cast(wolfSSL_ERR_peek_last_error()); +} + +uint64_t get_error() { + uint64_t err = impl::wolfssl_last_error(); + impl::wolfssl_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + wolfSSL_ERR_error_string(static_cast(code), buf); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + if (!pem || len == 0) { return nullptr; } + // Validate by attempting to load into a temporary ctx + WOLFSSL_CTX *tmp_ctx = wolfSSL_CTX_new(wolfTLSv1_2_client_method()); + if (!tmp_ctx) { return nullptr; } + int ret = wolfSSL_CTX_load_verify_buffer( + tmp_ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + wolfSSL_CTX_free(tmp_ctx); + if (ret != SSL_SUCCESS) { return nullptr; } + return static_cast( + new impl::WolfSSLCAStore{std::string(pem, len)}); +} + +void free_ca_store(ca_store_t store) { + delete static_cast(store); +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *wctx = static_cast(ctx); + auto *ca = static_cast(store); + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca->pem_data.data()), + static_cast(ca->pem_data.size()), SSL_FILETYPE_PEM); + if (ret == SSL_SUCCESS) { wctx->ca_pem_data_ += ca->pem_data; } + return ret == SSL_SUCCESS; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return 0; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { certs.push_back(static_cast(x509)); } + pos = end_pos; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return names; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char *name_str = wolfSSL_X509_NAME_oneline(subject, nullptr, 0); + if (name_str) { + names.push_back(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + wolfSSL_X509_free(x509); + } + pos = end_pos; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *wctx = static_cast(ctx); + + // Load new certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert_pem), + static_cast(strlen(cert_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password if provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load new private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key_pem), + static_cast(strlen(key_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca_pem), + static_cast(strlen(ca_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *wctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + wctx->has_verify_callback = static_cast(impl::get_verify_callback()); + + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify( + wctx->ctx, + wctx->verify_client + ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *wsession = + static_cast(const_cast(session)); + return wolfSSL_get_verify_result(wsession->ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + const char *str = + wolfSSL_X509_verify_cert_error_string(static_cast(error_code)); + return str ? std::string(str) : std::string(); +} + +} // namespace tls + +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + +// WebSocket implementation +namespace ws { + +bool WebSocket::send_frame(Opcode op, const char *data, size_t len, + bool fin) { + std::lock_guard lock(write_mutex_); + if (closed_) { return false; } + return detail::write_websocket_frame(strm_, op, data, len, fin, !is_server_); +} + +ReadResult WebSocket::read(std::string &msg) { + while (!closed_) { + Opcode opcode; + std::string payload; + bool fin; + + if (!impl::read_websocket_frame(strm_, opcode, payload, fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + + switch (opcode) { + case Opcode::Ping: { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Pong, payload.data(), + payload.size(), true, !is_server_); + continue; + } + case Opcode::Pong: continue; + case Opcode::Close: { + if (!closed_.exchange(true)) { + // Echo close frame back + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + return Fail; + } + case Opcode::Text: + case Opcode::Binary: { + auto result = opcode == Opcode::Text ? Text : Binary; + msg = std::move(payload); + + // Handle fragmentation + if (!fin) { + while (true) { + Opcode cont_opcode; + std::string cont_payload; + bool cont_fin; + if (!impl::read_websocket_frame( + strm_, cont_opcode, cont_payload, cont_fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + if (cont_opcode == Opcode::Ping) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Pong, cont_payload.data(), cont_payload.size(), + true, !is_server_); + continue; + } + if (cont_opcode == Opcode::Pong) { continue; } + if (cont_opcode == Opcode::Close) { + if (!closed_.exchange(true)) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Close, cont_payload.data(), + cont_payload.size(), true, !is_server_); + } + return Fail; + } + // RFC 6455: continuation frames must use opcode 0x0 + if (cont_opcode != Opcode::Continuation) { + closed_ = true; + return Fail; + } + msg += cont_payload; + if (msg.size() > CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH) { + closed_ = true; + return Fail; + } + if (cont_fin) { break; } + } + } + // RFC 6455 Section 5.6: text frames must contain valid UTF-8 + if (result == Text && !impl::is_valid_utf8(msg)) { + close(CloseStatus::InvalidPayload, "invalid UTF-8"); + return Fail; + } + return result; + } + default: closed_ = true; return Fail; + } + } + return Fail; +} + +bool WebSocket::send(const std::string &data) { + return send_frame(Opcode::Text, data.data(), data.size()); +} + +bool WebSocket::send(const char *data, size_t len) { + return send_frame(Opcode::Binary, data, len); +} + +void WebSocket::close(CloseStatus status, const std::string &reason) { + if (closed_.exchange(true)) { return; } + ping_cv_.notify_all(); + std::string payload; + auto code = static_cast(status); + payload.push_back(static_cast((code >> 8) & 0xFF)); + payload.push_back(static_cast(code & 0xFF)); + // RFC 6455 Section 5.5: control frame payload must not exceed 125 bytes + // Close frame has 2-byte status code, so reason is limited to 123 bytes + payload += reason.substr(0, 123); + { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + + // RFC 6455 Section 7.1.1: after sending a Close frame, wait for the peer's + // Close response before closing the TCP connection. Use a short timeout to + // avoid hanging if the peer doesn't respond. + strm_.set_read_timeout(CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND, 0); + Opcode op; + std::string resp; + bool fin; + while (impl::read_websocket_frame(strm_, op, resp, fin, is_server_, 125)) { + if (op == Opcode::Close) { break; } + } +} + +WebSocket::~WebSocket() { + { + std::lock_guard lock(ping_mutex_); + closed_ = true; + } + ping_cv_.notify_all(); + if (ping_thread_.joinable()) { ping_thread_.join(); } +} + +void WebSocket::start_heartbeat() { + ping_thread_ = std::thread([this]() { + std::unique_lock lock(ping_mutex_); + while (!closed_) { + ping_cv_.wait_for(lock, std::chrono::seconds( + CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND)); + if (closed_) { break; } + lock.unlock(); + if (!send_frame(Opcode::Ping, nullptr, 0)) { + closed_ = true; + break; + } + lock.lock(); + } + }); +} + +const Request &WebSocket::request() const { return req_; } + +bool WebSocket::is_open() const { return !closed_; } + +// WebSocketClient implementation +WebSocketClient::WebSocketClient( + const std::string &scheme_host_port_path, const Headers &headers) + : headers_(headers) { + const static std::regex re( + R"(([a-z]+):\/\/(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?(\/.*))"); + + std::smatch m; + if (std::regex_match(scheme_host_port_path, m, re)) { + auto scheme = m[1].str(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + if (scheme != "ws" && scheme != "wss") { +#else + if (scheme != "ws") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "wss"; + + host_ = m[2].str(); + if (host_.empty()) { host_ = m[3].str(); } + + auto port_str = m[4].str(); + port_ = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + + path_ = m[5].str(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + is_ssl_ = is_ssl; +#else + if (is_ssl) { return; } +#endif + + is_valid_ = true; + } +} + +WebSocketClient::~WebSocketClient() { shutdown_and_close(); } + +bool WebSocketClient::is_valid() const { return is_valid_; } + +void WebSocketClient::shutdown_and_close() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (tls_session_) { + tls::shutdown(tls_session_, true); + tls::free_session(tls_session_); + tls_session_ = nullptr; + } + if (tls_ctx_) { + tls::free_context(tls_ctx_); + tls_ctx_ = nullptr; + } + } +#endif + if (ws_ && ws_->is_open()) { ws_->close(); } + ws_.reset(); + if (sock_ != INVALID_SOCKET) { + detail::shutdown_socket(sock_); + detail::close_socket(sock_); + sock_ = INVALID_SOCKET; + } +} + +bool WebSocketClient::create_stream(std::unique_ptr &strm) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (!detail::setup_client_tls_session( + host_, tls_ctx_, tls_session_, sock_, + server_certificate_verification_, ca_cert_file_path_, + ca_cert_store_, read_timeout_sec_, read_timeout_usec_)) { + return false; + } + + strm = std::unique_ptr(new detail::SSLSocketStream( + sock_, tls_session_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; + } +#endif + strm = std::unique_ptr( + new detail::SocketStream(sock_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; +} + +bool WebSocketClient::connect() { + if (!is_valid_) { return false; } + shutdown_and_close(); + + Error error; + sock_ = detail::create_client_socket( + host_, std::string(), port_, AF_UNSPEC, false, false, nullptr, 5, 0, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, std::string(), error); + + if (sock_ == INVALID_SOCKET) { return false; } + + std::unique_ptr strm; + if (!create_stream(strm)) { + shutdown_and_close(); + return false; + } + + std::string selected_subprotocol; + if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_, + selected_subprotocol)) { + shutdown_and_close(); + return false; + } + subprotocol_ = std::move(selected_subprotocol); + + Request req; + req.method = "GET"; + req.path = path_; + ws_ = std::unique_ptr(new WebSocket(std::move(strm), req, false)); + return true; +} + +ReadResult WebSocketClient::read(std::string &msg) { + if (!ws_) { return Fail; } + return ws_->read(msg); +} + +bool WebSocketClient::send(const std::string &data) { + if (!ws_) { return false; } + return ws_->send(data); +} + +bool WebSocketClient::send(const char *data, size_t len) { + if (!ws_) { return false; } + return ws_->send(data, len); +} + +void WebSocketClient::close(CloseStatus status, + const std::string &reason) { + if (ws_) { ws_->close(status, reason); } +} + +bool WebSocketClient::is_open() const { return ws_ && ws_->is_open(); } + +const std::string &WebSocketClient::subprotocol() const { + return subprotocol_; +} + +void WebSocketClient::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +void WebSocketClient::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +#ifdef CPPHTTPLIB_SSL_ENABLED + +void WebSocketClient::set_ca_cert_path(const std::string &path) { + ca_cert_file_path_ = path; +} + +void WebSocketClient::set_ca_cert_store(tls::ca_store_t store) { + ca_cert_store_ = store; +} + +void +WebSocketClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +} // namespace ws + } // namespace httplib diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index f7563283ee..aea6fd308b 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.32.0" -#define CPPHTTPLIB_VERSION_NUM "0x002000" +#define CPPHTTPLIB_VERSION "0.35.0" +#define CPPHTTPLIB_VERSION_NUM "0x002300" /* * Platform compatibility check @@ -185,6 +185,14 @@ : 0)) #endif +#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT +#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT +#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds +#endif + #ifndef CPPHTTPLIB_RECV_FLAGS #define CPPHTTPLIB_RECV_FLAGS 0 #endif @@ -201,6 +209,22 @@ #define CPPHTTPLIB_MAX_LINE_LENGTH 32768 #endif +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH +#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND +#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30 +#endif + /* * Headers */ @@ -310,6 +334,7 @@ using socket_t = int; #include #include #include +#include #include #include #include @@ -328,6 +353,28 @@ using socket_t = int; #include #include #include +#if __cplusplus >= 201703L +#include +#endif + +// On macOS with a TLS backend, enable Keychain root certificates by default +// unless the user explicitly opts out. +#if defined(__APPLE__) && \ + !defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \ + (defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ + defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \ + defined(CPPHTTPLIB_WOLFSSL_SUPPORT)) +#ifndef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#define CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif +#endif + +// On Windows, enable Schannel certificate verification by default +// unless the user explicitly opts out. +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#define CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE +#endif #if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) @@ -335,8 +382,7 @@ using socket_t = int; #include #include #endif -#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or - // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef _WIN32 @@ -354,11 +400,11 @@ using socket_t = int; #endif #endif // _WIN32 -#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN #if TARGET_OS_MAC #include #endif -#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO +#endif #include #include @@ -402,11 +448,11 @@ using socket_t = int; #pragma comment(lib, "crypt32.lib") #endif #endif // _WIN32 -#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN #if TARGET_OS_MAC #include #endif -#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif // Mbed TLS 3.x API compatibility #if MBEDTLS_VERSION_MAJOR >= 3 @@ -415,10 +461,46 @@ using socket_t = int; #endif // CPPHTTPLIB_MBEDTLS_SUPPORT +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +#include + +#include + +// Fallback definitions for older wolfSSL versions (e.g., 5.6.6) +#ifndef WOLFSSL_GEN_EMAIL +#define WOLFSSL_GEN_EMAIL 1 +#endif +#ifndef WOLFSSL_GEN_DNS +#define WOLFSSL_GEN_DNS 2 +#endif +#ifndef WOLFSSL_GEN_URI +#define WOLFSSL_GEN_URI 6 +#endif +#ifndef WOLFSSL_GEN_IPADD +#define WOLFSSL_GEN_IPADD 7 +#endif + +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#if TARGET_OS_MAC +#include +#endif +#endif +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + // Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available -// This simplifies conditional compilation when adding new backends (e.g., -// wolfSSL) -#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ + defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) #define CPPHTTPLIB_SSL_ENABLED #endif @@ -440,6 +522,10 @@ using socket_t = int; */ namespace httplib { +namespace ws { +class WebSocket; +} // namespace ws + namespace detail { /* @@ -711,6 +797,143 @@ using Match = std::smatch; using DownloadProgress = std::function; using UploadProgress = std::function; + +#if __cplusplus >= 201703L + +using any = std::any; +using bad_any_cast = std::bad_any_cast; + +template T any_cast(const any &a) { return std::any_cast(a); } +template T any_cast(any &a) { return std::any_cast(a); } +template T any_cast(any &&a) { + return std::any_cast(std::move(a)); +} +template const T *any_cast(const any *a) noexcept { + return std::any_cast(a); +} +template T *any_cast(any *a) noexcept { + return std::any_cast(a); +} + +#else // C++11/14 implementation + +class bad_any_cast : public std::bad_cast { +public: + const char *what() const noexcept override { return "bad any_cast"; } +}; + +namespace detail { + +using any_type_id = const void *; + +// Returns a unique per-type ID without RTTI. +// The static address is stable across TUs because function templates are +// implicitly inline and the ODR merges their statics into one. +template any_type_id any_typeid() noexcept { + static const char id = 0; + return &id; +} + +struct any_storage { + virtual ~any_storage() = default; + virtual std::unique_ptr clone() const = 0; + virtual any_type_id type_id() const noexcept = 0; +}; + +template struct any_value final : any_storage { + T value; + template explicit any_value(U &&v) : value(std::forward(v)) {} + std::unique_ptr clone() const override { + return std::unique_ptr(new any_value(value)); + } + any_type_id type_id() const noexcept override { return any_typeid(); } +}; + +} // namespace detail + +class any { + std::unique_ptr storage_; + +public: + any() noexcept = default; + any(const any &o) : storage_(o.storage_ ? o.storage_->clone() : nullptr) {} + any(any &&) noexcept = default; + any &operator=(const any &o) { + storage_ = o.storage_ ? o.storage_->clone() : nullptr; + return *this; + } + any &operator=(any &&) noexcept = default; + + template < + typename T, typename D = typename std::decay::type, + typename std::enable_if::value, int>::type = 0> + any(T &&v) : storage_(new detail::any_value(std::forward(v))) {} + + template < + typename T, typename D = typename std::decay::type, + typename std::enable_if::value, int>::type = 0> + any &operator=(T &&v) { + storage_.reset(new detail::any_value(std::forward(v))); + return *this; + } + + bool has_value() const noexcept { return storage_ != nullptr; } + void reset() noexcept { storage_.reset(); } + + template friend T *any_cast(any *a) noexcept; + template friend const T *any_cast(const any *a) noexcept; +}; + +template T *any_cast(any *a) noexcept { + if (!a || !a->storage_) { return nullptr; } + if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(a->storage_.get())->value; +} + +template const T *any_cast(const any *a) noexcept { + if (!a || !a->storage_) { return nullptr; } + if (a->storage_->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(a->storage_.get())->value; +} + +template T any_cast(const any &a) { + using U = + typename std::remove_cv::type>::type; + const U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(*p); +} + +template T any_cast(any &a) { + using U = + typename std::remove_cv::type>::type; + U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(*p); +} + +template T any_cast(any &&a) { + using U = + typename std::remove_cv::type>::type; + U *p = any_cast(&a); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (!p) { throw bad_any_cast{}; } +#else + if (!p) { std::abort(); } +#endif + return static_cast(std::move(*p)); +} + +#endif // __cplusplus >= 201703L + struct Response; using ResponseHandler = std::function; @@ -805,6 +1028,60 @@ struct FormDataProvider { }; using FormDataProviderItems = std::vector; +inline FormDataProvider +make_file_provider(const std::string &name, const std::string &filepath, + const std::string &filename = std::string(), + const std::string &content_type = std::string()) { + FormDataProvider fdp; + fdp.name = name; + fdp.filename = filename.empty() ? filepath : filename; + fdp.content_type = content_type; + fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool { + std::ifstream f(filepath, std::ios::binary); + if (!f) { return false; } + if (offset > 0) { + f.seekg(static_cast(offset)); + if (!f.good()) { + sink.done(); + return true; + } + } + char buf[8192]; + f.read(buf, sizeof(buf)); + auto n = static_cast(f.gcount()); + if (n > 0) { return sink.write(buf, n); } + sink.done(); // EOF + return true; + }; + return fdp; +} + +inline std::pair +make_file_body(const std::string &filepath) { + std::ifstream f(filepath, std::ios::binary | std::ios::ate); + if (!f) { return {0, ContentProvider{}}; } + auto size = static_cast(f.tellg()); + + ContentProvider provider = [filepath](size_t offset, size_t length, + DataSink &sink) -> bool { + std::ifstream f(filepath, std::ios::binary); + if (!f) { return false; } + f.seekg(static_cast(offset)); + if (!f.good()) { return false; } + char buf[8192]; + while (length > 0) { + auto to_read = (std::min)(sizeof(buf), length); + f.read(buf, static_cast(to_read)); + auto n = static_cast(f.gcount()); + if (n == 0) { break; } + if (!sink.write(buf, n)) { return false; } + length -= n; + } + return true; + }; + return {size, std::move(provider)}; +} + using ContentReceiverWithProgress = std::function; @@ -1010,6 +1287,10 @@ struct Response { std::string body; std::string location; // Redirect location + // User-defined context — set by pre-routing/pre-request handlers and read + // by route handlers to pass arbitrary data (e.g. decoded auth tokens). + std::map user_data; + bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; @@ -1115,6 +1396,7 @@ public: virtual bool is_readable() const = 0; virtual bool wait_readable() const = 0; virtual bool wait_writable() const = 0; + virtual bool is_peer_alive() const { return wait_writable(); } virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -1124,6 +1406,11 @@ public: virtual time_t duration() const = 0; + virtual void set_read_timeout(time_t sec, time_t usec = 0) { + (void)sec; + (void)usec; + } + ssize_t write(const char *ptr); ssize_t write(const std::string &s); @@ -1146,7 +1433,7 @@ public: class ThreadPool final : public TaskQueue { public: - explicit ThreadPool(size_t n, size_t mqr = 0); + explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0); ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; @@ -1154,20 +1441,22 @@ public: void shutdown() override; private: - struct worker { - explicit worker(ThreadPool &pool); + void worker(bool is_dynamic); + void move_to_finished(std::thread::id id); + void cleanup_finished_threads(); - void operator()(); - - ThreadPool &pool_; - }; - friend struct worker; - - std::vector threads_; - std::list> jobs_; + size_t base_thread_count_; + size_t max_thread_count_; + size_t max_queued_requests_; + size_t idle_thread_count_; bool shutdown_; - size_t max_queued_requests_ = 0; + + std::list> jobs_; + std::vector threads_; // base threads + std::list dynamic_threads_; // dynamic threads + std::vector + finished_threads_; // exited dynamic threads awaiting join std::condition_variable cond_; std::mutex mutex_; @@ -1294,6 +1583,11 @@ public: using Expect100ContinueHandler = std::function; + using WebSocketHandler = + std::function; + using SubProtocolSelector = + std::function &protocols)>; + Server(); virtual ~Server(); @@ -1311,6 +1605,10 @@ public: Server &Delete(const std::string &pattern, HandlerWithContentReader handler); Server &Options(const std::string &pattern, Handler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector); + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); bool set_mount_point(const std::string &mount_point, const std::string &dir, @@ -1386,7 +1684,8 @@ protected: int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request); + const std::function &setup_request, + bool *websocket_upgraded = nullptr); std::atomic svr_sock_{INVALID_SOCKET}; @@ -1488,6 +1787,14 @@ private: HandlersForContentReader delete_handlers_for_content_reader_; Handlers options_handlers_; + struct WebSocketHandlerEntry { + std::unique_ptr matcher; + WebSocketHandler handler; + SubProtocolSelector sub_protocol_selector; + }; + using WebSocketHandlers = std::vector; + WebSocketHandlers websocket_handlers_; + HandlerWithResponse error_handler_; ExceptionHandler exception_handler_; HandlerWithResponse pre_routing_handler_; @@ -2268,8 +2575,7 @@ public: tls::ctx_t tls_context() const; -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE void enable_windows_certificate_verification(bool enabled); #endif @@ -2390,8 +2696,7 @@ public: tls::ctx_t tls_context() const { return ctx_; } -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE void enable_windows_certificate_verification(bool enabled); #endif @@ -2423,8 +2728,7 @@ private: std::function session_verifier_; -#if defined(_WIN32) && \ - !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE bool enable_windows_cert_verification_ = true; #endif @@ -2970,6 +3274,36 @@ struct MbedTlsContext { } // namespace tls #endif +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { +namespace impl { + +// wolfSSL context wrapper (holds WOLFSSL_CTX and related state). +// This struct is accessible via tls::impl for use in SSL context +// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*). +struct WolfSSLContext { + WOLFSSL_CTX *ctx = nullptr; + bool is_server = false; + bool verify_client = false; + bool has_verify_callback = false; + std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs + + WolfSSLContext(); + ~WolfSSLContext(); + + WolfSSLContext(const WolfSSLContext &) = delete; + WolfSSLContext &operator=(const WolfSSLContext &) = delete; +}; + +// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx +struct WolfSSLCAStore { + std::string pem_data; +}; + +} // namespace impl +} // namespace tls +#endif + #endif // CPPHTTPLIB_SSL_ENABLED namespace stream { @@ -3335,6 +3669,143 @@ private: } // namespace sse +namespace ws { + +enum class Opcode : uint8_t { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, +}; + +enum class CloseStatus : uint16_t { + Normal = 1000, + GoingAway = 1001, + ProtocolError = 1002, + UnsupportedData = 1003, + NoStatus = 1005, + Abnormal = 1006, + InvalidPayload = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + MandatoryExtension = 1010, + InternalError = 1011, +}; + +enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 }; + +class WebSocket { +public: + WebSocket(const WebSocket &) = delete; + WebSocket &operator=(const WebSocket &) = delete; + ~WebSocket(); + + ReadResult read(std::string &msg); + bool send(const std::string &data); + bool send(const char *data, size_t len); + void close(CloseStatus status = CloseStatus::Normal, + const std::string &reason = ""); + const Request &request() const; + bool is_open() const; + +private: + friend class httplib::Server; + friend class WebSocketClient; + + WebSocket(Stream &strm, const Request &req, bool is_server) + : strm_(strm), req_(req), is_server_(is_server) { + start_heartbeat(); + } + + WebSocket(std::unique_ptr &&owned_strm, const Request &req, + bool is_server) + : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req), + is_server_(is_server) { + start_heartbeat(); + } + + void start_heartbeat(); + bool send_frame(Opcode op, const char *data, size_t len, bool fin = true); + + Stream &strm_; + std::unique_ptr owned_strm_; + Request req_; + bool is_server_; + std::atomic closed_{false}; + std::mutex write_mutex_; + std::thread ping_thread_; + std::mutex ping_mutex_; + std::condition_variable ping_cv_; +}; + +class WebSocketClient { +public: + explicit WebSocketClient(const std::string &scheme_host_port_path, + const Headers &headers = {}); + + ~WebSocketClient(); + WebSocketClient(const WebSocketClient &) = delete; + WebSocketClient &operator=(const WebSocketClient &) = delete; + + bool is_valid() const; + + bool connect(); + ReadResult read(std::string &msg); + bool send(const std::string &data); + bool send(const char *data, size_t len); + void close(CloseStatus status = CloseStatus::Normal, + const std::string &reason = ""); + bool is_open() const; + const std::string &subprotocol() const; + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); + +#ifdef CPPHTTPLIB_SSL_ENABLED + void set_ca_cert_path(const std::string &path); + void set_ca_cert_store(tls::ca_store_t store); + void enable_server_certificate_verification(bool enabled); +#endif + +private: + void shutdown_and_close(); + bool create_stream(std::unique_ptr &strm); + + std::string host_; + int port_; + std::string path_; + Headers headers_; + std::string subprotocol_; + bool is_valid_ = false; + socket_t sock_ = INVALID_SOCKET; + std::unique_ptr ws_; + time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = 0; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + +#ifdef CPPHTTPLIB_SSL_ENABLED + bool is_ssl_ = false; + tls::ctx_t tls_ctx_ = nullptr; + tls::session_t tls_session_ = nullptr; + std::string ca_cert_file_path_; + tls::ca_store_t ca_cert_store_ = nullptr; + bool server_certificate_verification_ = true; +#endif +}; + +namespace impl { + +bool is_valid_utf8(const std::string &s); + +bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload, + bool &fin, bool expect_masked, size_t max_len); + +} // namespace impl + +} // namespace ws + } // namespace httplib diff --git a/vendor/miniaudio/miniaudio.h b/vendor/miniaudio/miniaudio.h index 2f5b9c4eaf..24e676bb26 100644 --- a/vendor/miniaudio/miniaudio.h +++ b/vendor/miniaudio/miniaudio.h @@ -1,6 +1,6 @@ /* Audio playback and capture library. Choice of public domain or MIT-0. See license statements at the end of this file. -miniaudio - v0.11.24 - TBD +miniaudio - v0.11.24 - 2026-01-17 David Reid - mackron@gmail.com @@ -3858,7 +3858,7 @@ typedef ma_uint16 wchar_t; /* Platform/backend detection. */ -#if defined(_WIN32) || defined(__COSMOPOLITAN__) +#if defined(_WIN32) #define MA_WIN32 #if defined(MA_FORCE_UWP) || (defined(WINAPI_FAMILY) && ((defined(WINAPI_FAMILY_PC_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PC_APP) || (defined(WINAPI_FAMILY_PHONE_APP) && WINAPI_FAMILY == WINAPI_FAMILY_PHONE_APP))) #define MA_WIN32_UWP @@ -4182,9 +4182,13 @@ typedef enum MA_CHANNEL_AUX_29 = 49, MA_CHANNEL_AUX_30 = 50, MA_CHANNEL_AUX_31 = 51, + + /* Count. */ + MA_CHANNEL_POSITION_COUNT, + + /* Aliases. */ MA_CHANNEL_LEFT = MA_CHANNEL_FRONT_LEFT, MA_CHANNEL_RIGHT = MA_CHANNEL_FRONT_RIGHT, - MA_CHANNEL_POSITION_COUNT = (MA_CHANNEL_AUX_31 + 1) } _ma_channel_position; /* Do not use `_ma_channel_position` directly. Use `ma_channel` instead. */ typedef enum @@ -6604,16 +6608,12 @@ This section contains the APIs for device playback and capture. Here is where yo #if defined(MA_WIN32_DESKTOP) /* DirectSound and WinMM backends are only supported on desktops. */ #define MA_SUPPORT_DSOUND #define MA_SUPPORT_WINMM - - /* Don't enable JACK here if compiling with Cosmopolitan. It'll be enabled in the Linux section below. */ - #if !defined(__COSMOPOLITAN__) - #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ - #endif + #define MA_SUPPORT_JACK /* JACK is technically supported on Windows, but I don't know how many people use it in practice... */ #endif #endif #if defined(MA_UNIX) && !defined(MA_ORBIS) && !defined(MA_PROSPERO) #if defined(MA_LINUX) - #if !defined(MA_ANDROID) && !defined(__COSMOPOLITAN__) /* ALSA is not supported on Android. */ + #if !defined(MA_ANDROID) && !defined(MA_EMSCRIPTEN) /* ALSA is not supported on Android. */ #define MA_SUPPORT_ALSA #endif #endif @@ -10520,6 +10520,7 @@ typedef struct ma_decoding_backend_vtable** ppCustomDecodingBackendVTables; ma_uint32 customDecodingBackendCount; void* pCustomDecodingBackendUserData; + ma_resampler_config resampling; } ma_resource_manager_config; MA_API ma_resource_manager_config ma_resource_manager_config_init(void); @@ -10847,6 +10848,7 @@ MA_API ma_result ma_node_graph_read_pcm_frames(ma_node_graph* pNodeGraph, void* MA_API ma_uint32 ma_node_graph_get_channels(const ma_node_graph* pNodeGraph); MA_API ma_uint64 ma_node_graph_get_time(const ma_node_graph* pNodeGraph); MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 globalTime); +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph); @@ -11154,6 +11156,7 @@ typedef struct ma_bool8 isPitchDisabled; /* Pitching can be explicitly disabled with MA_SOUND_FLAG_NO_PITCH to optimize processing. */ ma_bool8 isSpatializationDisabled; /* Spatialization can be explicitly disabled with MA_SOUND_FLAG_NO_SPATIALIZATION. */ ma_uint8 pinnedListenerIndex; /* The index of the listener this node should always use for spatialization. If set to MA_LISTENER_INDEX_CLOSEST the engine will use the closest listener. */ + ma_resampler_config resampling; } ma_engine_node_config; MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_engine_node_type type, ma_uint32 flags); @@ -11168,7 +11171,7 @@ typedef struct ma_uint32 volumeSmoothTimeInPCMFrames; ma_mono_expansion_mode monoExpansionMode; ma_fader fader; - ma_linear_resampler resampler; /* For pitch shift. */ + ma_resampler resampler; /* For pitch shift. */ ma_spatializer spatializer; ma_panner panner; ma_gainer volumeGainer; /* This will only be used if volumeSmoothTimeInPCMFrames is > 0. */ @@ -11224,6 +11227,7 @@ typedef struct ma_uint64 loopPointEndInPCMFrames; ma_sound_end_proc endCallback; /* Fired when the sound reaches the end. Will be fired from the audio thread. Do not restart, uninitialize or otherwise change the state of the sound from here. Instead fire an event or set a variable to indicate to a different thread to change the start of the sound. Will not be fired in response to a scheduled stop with ma_sound_set_stop_time_*(). */ void* pEndCallbackUserData; + ma_resampler_config pitchResampling; #ifndef MA_NO_RESOURCE_MANAGER ma_resource_manager_pipeline_notifications initNotifications; #endif @@ -11242,7 +11246,10 @@ struct ma_sound MA_ATOMIC(4, ma_bool32) atEnd; ma_sound_end_proc endCallback; void* pEndCallbackUserData; - ma_bool8 ownsDataSource; + float* pProcessingCache; /* Will be null if pDataSource is null. */ + ma_uint32 processingCacheFramesRemaining; + ma_uint32 processingCacheCap; + ma_bool8 ownsDataSource; /* We're declaring a resource manager data source object here to save us a malloc when loading a @@ -11300,6 +11307,8 @@ typedef struct ma_vfs* pResourceManagerVFS; /* A pointer to a pre-allocated VFS object to use with the resource manager. This is ignored if pResourceManager is not NULL. */ ma_engine_process_proc onProcess; /* Fired at the end of each call to ma_engine_read_pcm_frames(). For engine's that manage their own internal device (the default configuration), this will be fired from the audio thread, and you do not need to call ma_engine_read_pcm_frames() manually in order to trigger this. */ void* pProcessUserData; /* User data that's passed into onProcess. */ + ma_resampler_config resourceManagerResampling; /* The resampling config to use with the resource manager. */ + ma_resampler_config pitchResampling; /* The resampling config for the pitch and Doppler effects. You will typically want this to be a fast resampler. For high quality stuff, it's recommended that you pre-resample. */ } ma_engine_config; MA_API ma_engine_config ma_engine_config_init(void); @@ -11329,6 +11338,7 @@ struct ma_engine ma_mono_expansion_mode monoExpansionMode; ma_engine_process_proc onProcess; void* pProcessUserData; + ma_resampler_config pitchResamplingConfig; }; MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEngine); @@ -11389,8 +11399,12 @@ MA_API ma_engine* ma_sound_get_engine(const ma_sound* pSound); MA_API ma_data_source* ma_sound_get_data_source(const ma_sound* pSound); MA_API ma_result ma_sound_start(ma_sound* pSound); MA_API ma_result ma_sound_stop(ma_sound* pSound); -MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ -MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. */ +MA_API ma_result ma_sound_stop_with_fade_in_pcm_frames(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_uint64 fadeLengthInFrames); /* Will overwrite any scheduled stop and fade. If you want to restart the sound, first reset it with `ma_sound_reset_stop_time_and_fade()`. There are plans to make this less awkward in the future. */ +MA_API void ma_sound_reset_start_time(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time(ma_sound* pSound); +MA_API void ma_sound_reset_fade(ma_sound* pSound); +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound); /* Resets fades and scheduled stop time. Does not seek back to the start. */ MA_API void ma_sound_set_volume(ma_sound* pSound, float volume); MA_API float ma_sound_get_volume(const ma_sound* pSound); MA_API void ma_sound_set_pan(ma_sound* pSound, float pan); @@ -11643,7 +11657,7 @@ IMPLEMENTATION #endif /* Intrinsics Support */ -#if (defined(MA_X64) || defined(MA_X86)) && !defined(__COSMOPOLITAN__) +#if defined(MA_X64) || defined(MA_X86) #if defined(_MSC_VER) && !defined(__clang__) /* MSVC. */ #if _MSC_VER >= 1400 && !defined(MA_NO_SSE2) /* 2005 */ @@ -12080,7 +12094,7 @@ static MA_INLINE unsigned int ma_disable_denormals(void) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { prevState = _mm_getcsr(); _mm_setcsr(prevState | MA_MM_DENORMALS_ZERO_MASK | MA_MM_FLUSH_ZERO_MASK); @@ -12120,7 +12134,7 @@ static MA_INLINE void ma_restore_denormals(unsigned int prevState) } #elif defined(MA_X86) || defined(MA_X64) { - #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__) || defined(__COSMOPOLITAN__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ + #if defined(MA_SUPPORT_SSE2) && defined(__SSE2__) && !(defined(__TINYC__) || defined(__WATCOMC__)) /* <-- Add compilers that lack support for _mm_getcsr() and _mm_setcsr() to this list. */ { _mm_setcsr(prevState); } @@ -17616,7 +17630,7 @@ static ma_result ma_thread_create__posix(ma_thread* pThread, ma_thread_priority int priorityStep = (priorityMax - priorityMin) / 7; /* 7 = number of priorities supported by miniaudio. */ struct sched_param sched; - if (pthread_attr_getschedparam(&attr, &sched) == 0) { + if (priorityMin != -1 && priorityMax != -1 && pthread_attr_getschedparam(&attr, &sched) == 0) { if (priority == ma_thread_priority_idle) { sched.sched_priority = priorityMin; } else if (priority == ma_thread_priority_realtime) { @@ -20073,7 +20087,7 @@ Timing struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - pTimer->counter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; } static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) @@ -20084,7 +20098,7 @@ Timing struct timespec newTime; clock_gettime(MA_CLOCK_ID, &newTime); - newTimeCounter = (newTime.tv_sec * 1000000000) + newTime.tv_nsec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000000) + newTime.tv_nsec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000000.0; @@ -20095,7 +20109,7 @@ Timing struct timeval newTime; gettimeofday(&newTime, NULL); - pTimer->counter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + pTimer->counter = ((ma_int64)newTime.tv_sec * 1000000) + newTime.tv_usec; } static MA_INLINE double ma_timer_get_time_in_seconds(ma_timer* pTimer) @@ -20106,7 +20120,7 @@ Timing struct timeval newTime; gettimeofday(&newTime, NULL); - newTimeCounter = (newTime.tv_sec * 1000000) + newTime.tv_usec; + newTimeCounter = ((ma_uint64)newTime.tv_sec * 1000000) + newTime.tv_usec; oldTimeCounter = pTimer->counter; return (newTimeCounter - oldTimeCounter) / 1000000.0; @@ -31228,6 +31242,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, result = ma_result_from_pulse(((ma_pa_context_connect_proc)pContext->pulse.pa_context_connect)((ma_pa_context*)pPulseContext, pServerName, (tryAutoSpawn) ? MA_PA_CONTEXT_NOFLAGS : MA_PA_CONTEXT_NOAUTOSPAWN, NULL)); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Failed to connect PulseAudio context."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -31236,6 +31251,7 @@ static ma_result ma_init_pa_mainloop_and_pa_context__pulse(ma_context* pContext, result = ma_wait_for_pa_context_to_connect__pulse(pContext, pMainLoop, pPulseContext); if (result != MA_SUCCESS) { ma_log_postf(ma_context_get_log(pContext), MA_LOG_LEVEL_ERROR, "[PulseAudio] Waiting for connection failed."); + ((ma_pa_context_unref_proc)pContext->pulse.pa_context_unref)((ma_pa_context*)(pPulseContext)); ((ma_pa_mainloop_free_proc)pContext->pulse.pa_mainloop_free)((ma_pa_mainloop*)(pMainLoop)); return result; } @@ -41747,8 +41763,11 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const frameCount = pDevice->capture.internalPeriodSizeInFrames; } + /* + If this is called by the device has not yet been started we need to return early, making sure we output silence to + the output buffer. + */ if (ma_device_get_state(pDevice) != ma_device_state_started) { - /* Fill the output buffer with zero to avoid a noise sound */ for (int i = 0; i < outputCount; i += 1) { MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); } @@ -41770,7 +41789,9 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const if (outputCount > 0) { /* If it's a capture-only device, we'll need to output silence. */ if (pDevice->type == ma_device_type_capture) { - MA_ZERO_MEMORY(pOutputs[0].data, frameCount * pDevice->playback.internalChannels * sizeof(float)); + for (int i = 0; i < outputCount; i += 1) { + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } else { ma_device_process_pcm_frames_playback__webaudio(pDevice, frameCount, pDevice->webaudio.pIntermediaryBuffer); @@ -41780,6 +41801,14 @@ static EM_BOOL ma_audio_worklet_process_callback__webaudio(int inputCount, const pOutputs[0].data[frameCount*iChannel + iFrame] = pDevice->webaudio.pIntermediaryBuffer[iFrame*pDevice->playback.internalChannels + iChannel]; } } + + /* + Just above we output data to the first output buffer. Here we just make sure we're putting silence into any + remaining output buffers. + */ + for (int i = 1; i < outputCount; i += 1) { /* <-- Note that the counter starts at 1 instead of 0. */ + MA_ZERO_MEMORY(pOutputs[i].data, pOutputs[i].numberOfChannels * frameCount * sizeof(float)); + } } } @@ -50850,15 +50879,15 @@ static /*__attribute__((noinline))*/ ma_result ma_gainer_process_pcm_frames_inte a += d; } } + + pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); + pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } + frameCount -= interpolatedFrameCount; + /* Make sure the timer is updated. */ pGainer->t = (ma_uint32)ma_min(pGainer->t + interpolatedFrameCount, pGainer->config.smoothTimeInFrames); - - /* Adjust our arguments so the next part can work normally. */ - frameCount -= interpolatedFrameCount; - pFramesOut = ma_offset_ptr(pFramesOut, interpolatedFrameCount * sizeof(float)); - pFramesIn = ma_offset_ptr(pFramesIn, interpolatedFrameCount * sizeof(float)); } /* All we need to do here is apply the new gains using an optimized path. */ @@ -52286,13 +52315,16 @@ static float ma_calculate_angular_gain(ma_vec3f dirA, ma_vec3f dirB, float coneI MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, ma_spatializer_listener* pListener, void* pFramesOut, const void* pFramesIn, ma_uint64 frameCount) { - ma_channel* pChannelMapIn = pSpatializer->pChannelMapIn; - ma_channel* pChannelMapOut = pListener->config.pChannelMapOut; + ma_channel* pChannelMapIn; + ma_channel* pChannelMapOut; - if (pSpatializer == NULL) { + if (pSpatializer == NULL || pListener == NULL) { return MA_INVALID_ARGS; } + pChannelMapIn = pSpatializer->pChannelMapIn; + pChannelMapOut = pListener->config.pChannelMapOut; + /* If we're not spatializing we need to run an optimized path. */ if (ma_atomic_load_i32(&pSpatializer->attenuationModel) == ma_attenuation_model_none) { if (ma_spatializer_listener_is_enabled(pListener)) { @@ -52337,23 +52369,17 @@ MA_API ma_result ma_spatializer_process_pcm_frames(ma_spatializer* pSpatializer, We'll need the listener velocity for doppler pitch calculations. The speed of sound is defined by the listener, so we'll grab that here too. */ - if (pListener != NULL) { - listenerVel = ma_spatializer_listener_get_velocity(pListener); - speedOfSound = pListener->config.speedOfSound; - } else { - listenerVel = ma_vec3f_init_3f(0, 0, 0); - speedOfSound = MA_DEFAULT_SPEED_OF_SOUND; - } + listenerVel = ma_spatializer_listener_get_velocity(pListener); + speedOfSound = pListener->config.speedOfSound; - if (pListener == NULL || ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { - /* There's no listener or we're using relative positioning. */ + if (ma_spatializer_get_positioning(pSpatializer) == ma_positioning_relative) { relativePos = ma_spatializer_get_position(pSpatializer); relativeDir = ma_spatializer_get_direction(pSpatializer); } else { /* - We've found a listener and we're using absolute positioning. We need to transform the - sound's position and direction so that it's relative to listener. Later on we'll use - this for determining the factors to apply to each channel to apply the panning effect. + We're using absolute positioning. We need to transform the sound's position and + direction so that it's relative to listener. Later on we'll use this for determining + the factors to apply to each channel to apply the panning effect. */ ma_spatializer_get_relative_position_and_direction(pSpatializer, pListener, &relativePos, &relativeDir); } @@ -54388,7 +54414,7 @@ static ma_bool32 ma_is_spatial_channel_position(ma_channel channelPosition) return MA_FALSE; } - if (channelPosition >= MA_CHANNEL_AUX_0 && channelPosition <= MA_CHANNEL_AUX_31) { + if (channelPosition >= MA_CHANNEL_AUX_0) { return MA_FALSE; } @@ -61676,7 +61702,6 @@ static ma_result ma_default_vfs_info(ma_vfs* pVFS, ma_vfs_file file, ma_file_inf if (result == MA_NOT_IMPLEMENTED) { /* Not implemented. Fall back to seek/tell/seek. */ - ma_result result; ma_int64 cursor; ma_int64 sizeInBytes; @@ -61884,6 +61909,8 @@ Decoding and Encoding Headers. These are auto-generated from a tool. **************************************************************************************************************************************************************/ #if !defined(MA_NO_WAV) && (!defined(MA_NO_DECODING) || !defined(MA_NO_ENCODING)) +#define MA_HAS_WAV + /* dr_wav_h begin */ #ifndef ma_dr_wav_h #define ma_dr_wav_h @@ -61894,7 +61921,7 @@ extern "C" { #define MA_DR_WAV_XSTRINGIFY(x) MA_DR_WAV_STRINGIFY(x) #define MA_DR_WAV_VERSION_MAJOR 0 #define MA_DR_WAV_VERSION_MINOR 14 -#define MA_DR_WAV_VERSION_REVISION 1 +#define MA_DR_WAV_VERSION_REVISION 4 #define MA_DR_WAV_VERSION_STRING MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MAJOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_MINOR) "." MA_DR_WAV_XSTRINGIFY(MA_DR_WAV_VERSION_REVISION) #include #define MA_DR_WAVE_FORMAT_PCM 0x1 @@ -62317,6 +62344,8 @@ MA_API ma_bool32 ma_dr_wav_fourcc_equal(const ma_uint8* a, const char* b); #endif /* MA_NO_WAV */ #if !defined(MA_NO_FLAC) && !defined(MA_NO_DECODING) +#define MA_HAS_FLAC + /* dr_flac_h begin */ #ifndef ma_dr_flac_h #define ma_dr_flac_h @@ -62327,7 +62356,7 @@ extern "C" { #define MA_DR_FLAC_XSTRINGIFY(x) MA_DR_FLAC_STRINGIFY(x) #define MA_DR_FLAC_VERSION_MAJOR 0 #define MA_DR_FLAC_VERSION_MINOR 13 -#define MA_DR_FLAC_VERSION_REVISION 2 +#define MA_DR_FLAC_VERSION_REVISION 3 #define MA_DR_FLAC_VERSION_STRING MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MAJOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_MINOR) "." MA_DR_FLAC_XSTRINGIFY(MA_DR_FLAC_VERSION_REVISION) #include #if defined(_MSC_VER) && _MSC_VER >= 1700 @@ -62609,6 +62638,8 @@ MA_API ma_bool32 ma_dr_flac_next_cuesheet_track(ma_dr_flac_cuesheet_track_iterat #endif /* MA_NO_FLAC */ #if !defined(MA_NO_MP3) && !defined(MA_NO_DECODING) +#define MA_HAS_MP3 + #ifndef MA_DR_MP3_NO_SIMD #if (defined(MA_NO_NEON) && defined(MA_ARM)) || (defined(MA_NO_SSE2) && (defined(MA_X86) || defined(MA_X64))) #define MA_DR_MP3_NO_SIMD @@ -62625,7 +62656,7 @@ extern "C" { #define MA_DR_MP3_XSTRINGIFY(x) MA_DR_MP3_STRINGIFY(x) #define MA_DR_MP3_VERSION_MAJOR 0 #define MA_DR_MP3_VERSION_MINOR 7 -#define MA_DR_MP3_VERSION_REVISION 2 +#define MA_DR_MP3_VERSION_REVISION 3 #define MA_DR_MP3_VERSION_STRING MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MAJOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_MINOR) "." MA_DR_MP3_XSTRINGIFY(MA_DR_MP3_VERSION_REVISION) #include #define MA_DR_MP3_MAX_PCM_FRAMES_PER_MP3_FRAME 1152 @@ -63229,7 +63260,6 @@ static ma_result ma_decoder_init_custom_from_memory__internal(const void* pData, /* WAV */ #ifdef ma_dr_wav_h -#define MA_HAS_WAV typedef struct { @@ -63935,7 +63965,6 @@ static ma_result ma_decoder_init_wav_from_memory__internal(const void* pData, si /* FLAC */ #ifdef ma_dr_flac_h -#define MA_HAS_FLAC typedef struct { @@ -64579,7 +64608,6 @@ static ma_result ma_decoder_init_flac_from_memory__internal(const void* pData, s /* MP3 */ #ifdef ma_dr_mp3_h -#define MA_HAS_MP3 typedef struct { @@ -66257,11 +66285,9 @@ static ma_result ma_decoder_init__internal(ma_decoder_read_proc onRead, ma_decod We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(pConfig, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(pConfig, pDecoder); - if (result != MA_SUCCESS) { - onSeek(pDecoder, 0, ma_seek_origin_start); - } + onSeek(pDecoder, 0, ma_seek_origin_start); } /* @@ -66525,14 +66551,6 @@ MA_API ma_result ma_decoder_init_memory(const void* pData, size_t dataSize, cons /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -66833,11 +66851,9 @@ MA_API ma_result ma_decoder_init_vfs(ma_vfs* pVFS, const char* pFilePath, const We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -66966,11 +66982,9 @@ MA_API ma_result ma_decoder_init_vfs_w(ma_vfs* pVFS, const wchar_t* pFilePath, c We use trial and error to open a decoder. We prioritize custom decoders so that if they implement the same encoding format they take priority over the built-in decoders. */ + result = ma_decoder_init_custom__internal(&config, pDecoder); if (result != MA_SUCCESS) { - result = ma_decoder_init_custom__internal(&config, pDecoder); - if (result != MA_SUCCESS) { - ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); - } + ma_decoder__on_seek_vfs(pDecoder, 0, ma_seek_origin_start); } /* @@ -67152,14 +67166,6 @@ MA_API ma_result ma_decoder_init_file(const char* pFilePath, const ma_decoder_co /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -67302,14 +67308,6 @@ MA_API ma_result ma_decoder_init_file_w(const wchar_t* pFilePath, const ma_decod /* Initialization was successful. Finish up. */ result = ma_decoder__postinit(&config, pDecoder); if (result != MA_SUCCESS) { - /* - The backend was initialized successfully, but for some reason post-initialization failed. This is most likely - due to an out of memory error. We're going to abort with an error here and not try to recover. - */ - if (pDecoder->pBackendVTable != NULL && pDecoder->pBackendVTable->onUninit != NULL) { - pDecoder->pBackendVTable->onUninit(pDecoder->pBackendUserData, &pDecoder->pBackend, &pDecoder->allocationCallbacks); - } - return result; } } else { @@ -69955,6 +69953,7 @@ MA_API ma_resource_manager_config ma_resource_manager_config_init(void) config.decodedSampleRate = 0; config.jobThreadCount = 1; /* A single miniaudio-managed job thread by default. */ config.jobQueueCapacity = MA_JOB_TYPE_RESOURCE_MANAGER_QUEUE_CAPACITY; + config.resampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); /* Format/channels/rate doesn't matter here. */ /* Flags. */ config.flags = 0; @@ -70208,6 +70207,7 @@ static ma_decoder_config ma_resource_manager__init_decoder_config(ma_resource_ma config.ppCustomBackendVTables = pResourceManager->config.ppCustomDecodingBackendVTables; config.customBackendCount = pResourceManager->config.customDecodingBackendCount; config.pCustomBackendUserData = pResourceManager->config.pCustomDecodingBackendUserData; + config.resampling = pResourceManager->config.resampling; return config; } @@ -71533,13 +71533,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_data_format(ma_resource_man MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pCursor) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pCursor == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + *pCursor = 0; switch (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode)) @@ -71573,13 +71573,13 @@ MA_API ma_result ma_resource_manager_data_buffer_get_cursor_in_pcm_frames(ma_res MA_API ma_result ma_resource_manager_data_buffer_get_length_in_pcm_frames(ma_resource_manager_data_buffer* pDataBuffer, ma_uint64* pLength) { - /* We cannot be using the data source after it's been uninitialized. */ - MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); - if (pDataBuffer == NULL || pLength == NULL) { return MA_INVALID_ARGS; } + /* We cannot be using the data source after it's been uninitialized. */ + MA_ASSERT(ma_resource_manager_data_buffer_node_result(pDataBuffer->pNode) != MA_UNAVAILABLE); + if (ma_resource_manager_data_buffer_node_get_data_supply_type(pDataBuffer->pNode) == ma_resource_manager_data_supply_type_unknown) { return MA_BUSY; /* Still loading. */ } @@ -72934,8 +72934,6 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* return ma_resource_manager_post_job(pResourceManager, pJob); /* Out of order. */ } - ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); - /* The event needs to be signalled last. */ if (pJob->data.resourceManager.freeDataBufferNode.pDoneNotification != NULL) { ma_async_notification_signal(pJob->data.resourceManager.freeDataBufferNode.pDoneNotification); @@ -72946,6 +72944,9 @@ static ma_result ma_job_process__resource_manager__free_data_buffer_node(ma_job* } ma_atomic_fetch_add_32(&pDataBufferNode->executionPointer, 1); + + ma_resource_manager_data_buffer_node_free(pResourceManager, pDataBufferNode); + return MA_SUCCESS; } @@ -73818,6 +73819,15 @@ MA_API ma_result ma_node_graph_set_time(ma_node_graph* pNodeGraph, ma_uint64 glo return ma_node_set_time(&pNodeGraph->endpoint, globalTime); /* Global time is just the local time of the endpoint. */ } +MA_API ma_uint32 ma_node_graph_get_processing_size_in_frames(const ma_node_graph* pNodeGraph) +{ + if (pNodeGraph == NULL) { + return 0; + } + + return pNodeGraph->processingSizeInFrames; +} + #define MA_NODE_OUTPUT_BUS_FLAG_HAS_READ 0x01 /* Whether or not this bus ready to read more data. Only used on nodes with multiple output buses. */ @@ -74977,12 +74987,12 @@ MA_API ma_node_state ma_node_get_state_by_time_range(const ma_node* pNode, ma_ui its start time not having been reached yet. Also, the stop time may have also been reached in which case it'll be considered stopped. */ - if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeBeg) { - return ma_node_state_stopped; /* Start time has not yet been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_stopped) < globalTimeBeg) { + return ma_node_state_stopped; /* End time is before the start of the range. */ } - if (ma_node_get_state_time(pNode, ma_node_state_stopped) <= globalTimeEnd) { - return ma_node_state_stopped; /* Stop time has been reached. */ + if (ma_node_get_state_time(pNode, ma_node_state_started) > globalTimeEnd) { + return ma_node_state_stopped; /* Start time is after the end of the range. */ } /* Getting here means the node is marked as started and is within its start/stop times. */ @@ -75062,14 +75072,14 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde return MA_INVALID_ARGS; /* Invalid output bus index. */ } + globalTimeBeg = globalTime; + globalTimeEnd = globalTime + frameCount; + /* Don't do anything if we're in a stopped state. */ - if (ma_node_get_state_by_time_range(pNode, globalTime, globalTime + frameCount) != ma_node_state_started) { + if (ma_node_get_state_by_time_range(pNode, globalTimeBeg, globalTimeEnd) != ma_node_state_started) { return MA_SUCCESS; /* We're in a stopped state. This is not an error - we just need to not read anything. */ } - - globalTimeBeg = globalTime; - globalTimeEnd = globalTime + frameCount; startTime = ma_node_get_state_time(pNode, ma_node_state_started); stopTime = ma_node_get_state_time(pNode, ma_node_state_stopped); @@ -75082,11 +75092,16 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde therefore need to offset it by a number of frames to accommodate. The same thing applies for the stop time. */ - timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(globalTimeEnd - startTime) : 0; + timeOffsetBeg = (globalTimeBeg < startTime) ? (ma_uint32)(startTime - globalTimeBeg) : 0; timeOffsetEnd = (globalTimeEnd > stopTime) ? (ma_uint32)(globalTimeEnd - stopTime) : 0; /* Trim based on the start offset. We need to silence the start of the buffer. */ if (timeOffsetBeg > 0) { + MA_ASSERT(timeOffsetBeg <= frameCount); + if (timeOffsetBeg > frameCount) { + timeOffsetBeg = frameCount; + } + ma_silence_pcm_frames(pFramesOut, timeOffsetBeg, ma_format_f32, ma_node_get_output_channels(pNode, outputBusIndex)); pFramesOut += timeOffsetBeg * ma_node_get_output_channels(pNode, outputBusIndex); frameCount -= timeOffsetBeg; @@ -75094,6 +75109,11 @@ static ma_result ma_node_read_pcm_frames(ma_node* pNode, ma_uint32 outputBusInde /* Trim based on the end offset. We don't need to silence the tail section because we'll just have a reduced value written to pFramesRead. */ if (timeOffsetEnd > 0) { + MA_ASSERT(timeOffsetEnd <= frameCount); + if (timeOffsetEnd > frameCount) { + timeOffsetEnd = frameCount; + } + frameCount -= timeOffsetEnd; } @@ -76508,12 +76528,20 @@ static void ma_sound_set_at_end(ma_sound* pSound, ma_bool32 atEnd) MA_ASSERT(pSound != NULL); ma_atomic_exchange_32(&pSound->atEnd, atEnd); + /* + When this function is called the state of the sound will not yet be in a stopped state. This makes it confusing + because an end callback will intuitively expect ma_sound_is_playing() to return false from inside the callback. + I'm therefore no longer firing the callback here and will instead fire it manually in the *next* processing step + when the state should be set to stopped as expected. + */ + #if 0 /* Fire any callbacks or events. */ if (atEnd) { if (pSound->endCallback != NULL) { pSound->endCallback(pSound->pEndCallbackUserData, pSound); } } + #endif } static ma_bool32 ma_sound_get_at_end(const ma_sound* pSound) @@ -76533,6 +76561,7 @@ MA_API ma_engine_node_config ma_engine_node_config_init(ma_engine* pEngine, ma_e config.isPitchDisabled = (flags & MA_SOUND_FLAG_NO_PITCH) != 0; config.isSpatializationDisabled = (flags & MA_SOUND_FLAG_NO_SPATIALIZATION) != 0; config.monoExpansionMode = pEngine->monoExpansionMode; + config.resampling = pEngine->pitchResamplingConfig; return config; } @@ -76559,7 +76588,7 @@ static void ma_engine_node_update_pitch_if_required(ma_engine_node* pEngineNode) if (isUpdateRequired) { float basePitch = (float)pEngineNode->sampleRate / ma_engine_get_sample_rate(pEngineNode->pEngine); - ma_linear_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); + ma_resampler_set_rate_ratio(&pEngineNode->resampler, basePitch * pEngineNode->oldPitch * pEngineNode->oldDopplerPitch); } } @@ -76578,22 +76607,6 @@ static ma_bool32 ma_engine_node_is_spatialization_enabled(const ma_engine_node* return !ma_atomic_load_explicit_32(&pEngineNode->isSpatializationDisabled, ma_atomic_memory_order_acquire); } -static ma_uint64 ma_engine_node_get_required_input_frame_count(const ma_engine_node* pEngineNode, ma_uint64 outputFrameCount) -{ - ma_uint64 inputFrameCount = 0; - - if (ma_engine_node_is_pitching_enabled(pEngineNode)) { - ma_result result = ma_linear_resampler_get_required_input_frame_count(&pEngineNode->resampler, outputFrameCount, &inputFrameCount); - if (result != MA_SUCCESS) { - inputFrameCount = 0; - } - } else { - inputFrameCount = outputFrameCount; /* No resampling, so 1:1. */ - } - - return inputFrameCount; -} - static ma_result ma_engine_node_set_volume(ma_engine_node* pEngineNode, float volume) { if (pEngineNode == NULL) { @@ -76735,7 +76748,7 @@ static void ma_engine_node_process_pcm_frames__general(ma_engine_node* pEngineNo ma_uint64 resampleFrameCountIn = framesAvailableIn; ma_uint64 resampleFrameCountOut = framesAvailableOut; - ma_linear_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); + ma_resampler_process_pcm_frames(&pEngineNode->resampler, pRunningFramesIn, &resampleFrameCountIn, pWorkingBuffer, &resampleFrameCountOut); isWorkingBufferValid = MA_TRUE; framesJustProcessedIn = (ma_uint32)resampleFrameCountIn; @@ -76859,6 +76872,11 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* If we're marked at the end we need to stop the sound and do nothing. */ if (ma_sound_at_end(pSound)) { ma_sound_stop(pSound); + + if (pSound->endCallback != NULL) { + pSound->endCallback(pSound->pEndCallbackUserData, pSound); + } + *pFrameCountOut = 0; return; } @@ -76896,55 +76914,74 @@ static void ma_engine_node_process_pcm_frames__sound(ma_node* pNode, const float /* Keep reading until we've read as much as was requested or we reach the end of the data source. */ while (totalFramesRead < frameCount) { ma_uint32 framesRemaining = frameCount - totalFramesRead; - ma_uint32 framesToRead; ma_uint64 framesJustRead; ma_uint32 frameCountIn; ma_uint32 frameCountOut; const float* pRunningFramesIn; float* pRunningFramesOut; - /* - The first thing we need to do is read into the temporary buffer. We can calculate exactly - how many input frames we'll need after resampling. - */ - framesToRead = (ma_uint32)ma_engine_node_get_required_input_frame_count(&pSound->engineNode, framesRemaining); - if (framesToRead > tempCapInFrames) { - framesToRead = tempCapInFrames; - } + /* If there's any input frames sitting in the cache get those processed first. */ + if (pSound->processingCacheFramesRemaining > 0) { + pRunningFramesIn = pSound->pProcessingCache; + frameCountIn = pSound->processingCacheFramesRemaining; - result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToRead, &framesJustRead); + pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); + frameCountOut = framesRemaining; - /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ - if (result == MA_AT_END) { - ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ - } - - pRunningFramesOut = ma_offset_pcm_frames_ptr_f32(ppFramesOut[0], totalFramesRead, ma_node_get_output_channels(pNode, 0)); - - frameCountIn = (ma_uint32)framesJustRead; - frameCountOut = framesRemaining; - - /* Convert if necessary. */ - if (dataSourceFormat == ma_format_f32) { - /* Fast path. No data conversion necessary. */ - pRunningFramesIn = (float*)temp; ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); + + MA_ASSERT(frameCountIn <= pSound->processingCacheFramesRemaining); + pSound->processingCacheFramesRemaining -= frameCountIn; + + /* Move any remaining data in the cache down. */ + if (pSound->processingCacheFramesRemaining > 0) { + MA_MOVE_MEMORY(pSound->pProcessingCache, ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, frameCountIn, dataSourceChannels), pSound->processingCacheFramesRemaining * ma_get_bytes_per_frame(ma_format_f32, dataSourceChannels)); + } + + totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; /* Might have reached the end. */ + } } else { - /* Slow path. Need to do sample format conversion to f32. If we give the f32 buffer the same count as the first temp buffer, we're guaranteed it'll be large enough. */ - float tempf32[MA_DATA_CONVERTER_STACK_BUFFER_SIZE]; /* Do not do `MA_DATA_CONVERTER_STACK_BUFFER_SIZE/sizeof(float)` here like we've done in other places. */ - ma_convert_pcm_frames_format(tempf32, ma_format_f32, temp, dataSourceFormat, framesJustRead, dataSourceChannels, ma_dither_mode_none); + /* Getting here means there's nothing in the cache. Read more data from the data source. */ + if (dataSourceFormat == ma_format_f32) { + /* Fast path. No conversion to f32 necessary. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, pSound->pProcessingCache, pSound->processingCacheCap, &framesJustRead); + } else { + /* Slow path. Need to convert to f32. */ + ma_uint64 totalFramesConverted = 0; - /* Now that we have our samples in f32 format we can process like normal. */ - pRunningFramesIn = tempf32; - ma_engine_node_process_pcm_frames__general(&pSound->engineNode, &pRunningFramesIn, &frameCountIn, &pRunningFramesOut, &frameCountOut); - } + while (totalFramesConverted < pSound->processingCacheCap) { + ma_uint64 framesConverted; + ma_uint32 framesToConvertThisIteration = pSound->processingCacheCap - (ma_uint32)totalFramesConverted; + if (framesToConvertThisIteration > tempCapInFrames) { + framesToConvertThisIteration = tempCapInFrames; + } - /* We should have processed all of our input frames since we calculated the required number of input frames at the top. */ - MA_ASSERT(frameCountIn == framesJustRead); - totalFramesRead += (ma_uint32)frameCountOut; /* Safe cast. */ + result = ma_data_source_read_pcm_frames(pSound->pDataSource, temp, framesToConvertThisIteration, &framesConverted); + if (result != MA_SUCCESS) { + break; + } - if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { - break; /* Might have reached the end. */ + ma_convert_pcm_frames_format(ma_offset_pcm_frames_ptr_f32(pSound->pProcessingCache, totalFramesConverted, dataSourceChannels), ma_format_f32, temp, dataSourceFormat, framesConverted, dataSourceChannels, ma_dither_mode_none); + totalFramesConverted += framesConverted; + } + + framesJustRead = totalFramesConverted; + } + + MA_ASSERT(framesJustRead <= pSound->processingCacheCap); + pSound->processingCacheFramesRemaining = (ma_uint32)framesJustRead; + + /* If we reached the end of the sound we'll want to mark it as at the end and stop it. This should never be returned for looping sounds. */ + if (result == MA_AT_END) { + ma_sound_set_at_end(pSound, MA_TRUE); /* This will be set to false in ma_sound_start(). */ + } + + if (result != MA_SUCCESS || ma_sound_at_end(pSound)) { + break; + } } } } @@ -76967,25 +77004,6 @@ static void ma_engine_node_process_pcm_frames__group(ma_node* pNode, const float ma_engine_node_process_pcm_frames__general((ma_engine_node*)pNode, ppFramesIn, pFrameCountIn, ppFramesOut, pFrameCountOut); } -static ma_result ma_engine_node_get_required_input_frame_count__group(ma_node* pNode, ma_uint32 outputFrameCount, ma_uint32* pInputFrameCount) -{ - ma_uint64 inputFrameCount; - - MA_ASSERT(pInputFrameCount != NULL); - - /* Our pitch will affect this calculation. We need to update it. */ - ma_engine_node_update_pitch_if_required((ma_engine_node*)pNode); - - inputFrameCount = ma_engine_node_get_required_input_frame_count((ma_engine_node*)pNode, outputFrameCount); - if (inputFrameCount > 0xFFFFFFFF) { - inputFrameCount = 0xFFFFFFFF; /* Will never happen because miniaudio will only ever process in relatively small chunks. */ - } - - *pInputFrameCount = (ma_uint32)inputFrameCount; - - return MA_SUCCESS; -} - static ma_node_vtable g_ma_engine_node_vtable__sound = { @@ -76999,7 +77017,7 @@ static ma_node_vtable g_ma_engine_node_vtable__sound = static ma_node_vtable g_ma_engine_node_vtable__group = { ma_engine_node_process_pcm_frames__group, - ma_engine_node_get_required_input_frame_count__group, + NULL, /* onGetRequiredInputFrameCount */ 1, /* Groups have one input bus. */ 1, /* Groups have one output bus. */ MA_NODE_FLAG_DIFFERENT_PROCESSING_RATES /* The engine node does resampling so should let miniaudio know about it. */ @@ -77045,9 +77063,10 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo ma_result result; size_t tempHeapSize; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_spatializer_config spatializerConfig; ma_gainer_config gainerConfig; + ma_uint32 sampleRate; ma_uint32 channelsIn; ma_uint32 channelsOut; ma_channel defaultStereoChannelMap[2] = {MA_CHANNEL_SIDE_LEFT, MA_CHANNEL_SIDE_RIGHT}; /* <-- Consistent with the default channel map of a stereo listener. Means channel conversion can run on a fast path. */ @@ -77066,6 +77085,7 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo pHeapLayout->sizeInBytes = 0; + sampleRate = (pConfig->sampleRate > 0) ? pConfig->sampleRate : ma_engine_get_sample_rate(pConfig->pEngine); channelsIn = (pConfig->channelsIn != 0) ? pConfig->channelsIn : ma_engine_get_channels(pConfig->pEngine); channelsOut = (pConfig->channelsOut != 0) ? pConfig->channelsOut : ma_engine_get_channels(pConfig->pEngine); @@ -77085,10 +77105,13 @@ static ma_result ma_engine_node_get_heap_layout(const ma_engine_node_config* pCo /* Resmapler. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, channelsIn, 1, 1); /* Input and output sample rates don't affect the calculation of the heap size. */ - resamplerConfig.lpfOrder = 0; + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = channelsIn; + resamplerConfig.sampleRateIn = sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pConfig->pEngine); - result = ma_linear_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); + result = ma_resampler_get_heap_size(&resamplerConfig, &tempHeapSize); if (result != MA_SUCCESS) { return result; /* Failed to retrieve the size of the heap for the resampler. */ } @@ -77156,7 +77179,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p ma_result result; ma_engine_node_heap_layout heapLayout; ma_node_config baseNodeConfig; - ma_linear_resampler_config resamplerConfig; + ma_resampler_config resamplerConfig; ma_fader_config faderConfig; ma_spatializer_config spatializerConfig; ma_panner_config pannerConfig; @@ -77231,10 +77254,13 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p */ /* We'll always do resampling first. */ - resamplerConfig = ma_linear_resampler_config_init(ma_format_f32, baseNodeConfig.pInputChannels[0], pEngineNode->sampleRate, ma_engine_get_sample_rate(pEngineNode->pEngine)); - resamplerConfig.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ + resamplerConfig = pConfig->resampling; + resamplerConfig.format = ma_format_f32; + resamplerConfig.channels = baseNodeConfig.pInputChannels[0]; + resamplerConfig.sampleRateIn = pEngineNode->sampleRate; + resamplerConfig.sampleRateOut = ma_engine_get_sample_rate(pEngineNode->pEngine); - result = ma_linear_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); + result = ma_resampler_init_preallocated(&resamplerConfig, ma_offset_ptr(pHeap, heapLayout.resamplerOffset), &pEngineNode->resampler); if (result != MA_SUCCESS) { goto error1; } @@ -77293,7 +77319,7 @@ MA_API ma_result ma_engine_node_init_preallocated(const ma_engine_node_config* p /* No need for allocation callbacks here because we use a preallocated heap. */ error3: ma_spatializer_uninit(&pEngineNode->spatializer, NULL); -error2: ma_linear_resampler_uninit(&pEngineNode->resampler, NULL); +error2: ma_resampler_uninit(&pEngineNode->resampler, NULL); error1: ma_node_uninit(&pEngineNode->baseNode, NULL); error0: return result; } @@ -77342,7 +77368,7 @@ MA_API void ma_engine_node_uninit(ma_engine_node* pEngineNode, const ma_allocati } ma_spatializer_uninit(&pEngineNode->spatializer, pAllocationCallbacks); - ma_linear_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); + ma_resampler_uninit(&pEngineNode->resampler, pAllocationCallbacks); /* Free the heap last. */ if (pEngineNode->_ownsHeap) { @@ -77364,8 +77390,12 @@ MA_API ma_sound_config ma_sound_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } config.rangeEndInPCMFrames = ~((ma_uint64)0); @@ -77387,8 +77417,12 @@ MA_API ma_sound_group_config ma_sound_group_config_init_2(ma_engine* pEngine) if (pEngine != NULL) { config.monoExpansionMode = pEngine->monoExpansionMode; + config.pitchResampling = pEngine->pitchResamplingConfig; } else { config.monoExpansionMode = ma_mono_expansion_mode_default; + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ } return config; @@ -77400,8 +77434,12 @@ MA_API ma_engine_config ma_engine_config_init(void) ma_engine_config config; MA_ZERO_OBJECT(&config); - config.listenerCount = 1; /* Always want at least one listener. */ - config.monoExpansionMode = ma_mono_expansion_mode_default; + config.listenerCount = 1; /* Always want at least one listener. */ + config.monoExpansionMode = ma_mono_expansion_mode_default; + config.resourceManagerResampling = ma_resampler_config_init(ma_format_unknown, 0, 0, 0, ma_resample_algorithm_linear); + + config.pitchResampling = ma_resampler_config_init(ma_format_f32, 0, 0, 0, ma_resample_algorithm_linear); + config.pitchResampling.linear.lpfOrder = 0; /* <-- Need to disable low-pass filtering for pitch shifting for now because there's cases where the biquads are becoming unstable. Need to figure out a better fix for this. */ return config; } @@ -77482,6 +77520,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng pEngine->defaultVolumeSmoothTimeInPCMFrames = engineConfig.defaultVolumeSmoothTimeInPCMFrames; pEngine->onProcess = engineConfig.onProcess; pEngine->pProcessUserData = engineConfig.pProcessUserData; + pEngine->pitchResamplingConfig = engineConfig.pitchResampling; ma_allocation_callbacks_init_copy(&pEngine->allocationCallbacks, &engineConfig.allocationCallbacks); #if !defined(MA_NO_RESOURCE_MANAGER) @@ -77664,6 +77703,7 @@ MA_API ma_result ma_engine_init(const ma_engine_config* pConfig, ma_engine* pEng resourceManagerConfig.decodedSampleRate = ma_engine_get_sample_rate(pEngine); ma_allocation_callbacks_init_copy(&resourceManagerConfig.allocationCallbacks, &pEngine->allocationCallbacks); resourceManagerConfig.pVFS = engineConfig.pResourceManagerVFS; + resourceManagerConfig.resampling = engineConfig.resourceManagerResampling; /* The Emscripten build cannot use threads unless it's targeting pthreads. */ #if defined(MA_EMSCRIPTEN) && !defined(__EMSCRIPTEN_PTHREADS__) @@ -78389,6 +78429,25 @@ static ma_result ma_sound_init_from_data_source_internal(ma_engine* pEngine, con } + /* + When pulling data from a data source we need a processing cache to hold onto unprocessed input data from the data source + after doing resampling. + */ + if (pSound->pDataSource != NULL) { + pSound->processingCacheFramesRemaining = 0; + pSound->processingCacheCap = ma_node_graph_get_processing_size_in_frames(&pEngine->nodeGraph); + if (pSound->processingCacheCap == 0) { + pSound->processingCacheCap = 512; + } + + pSound->pProcessingCache = (float*)ma_calloc(pSound->processingCacheCap * ma_get_bytes_per_frame(ma_format_f32, engineNodeConfig.channelsIn), &pEngine->allocationCallbacks); + if (pSound->pProcessingCache == NULL) { + ma_engine_node_uninit(&pSound->engineNode, &pEngine->allocationCallbacks); + return MA_OUT_OF_MEMORY; + } + } + + /* Apply initial range and looping state to the data source if applicable. */ if (pConfig->rangeBegInPCMFrames != 0 || pConfig->rangeEndInPCMFrames != ~((ma_uint64)0)) { ma_data_source_set_range_in_pcm_frames(ma_sound_get_data_source(pSound), pConfig->rangeBegInPCMFrames, pConfig->rangeEndInPCMFrames); @@ -78626,6 +78685,11 @@ MA_API void ma_sound_uninit(ma_sound* pSound) */ ma_engine_node_uninit(&pSound->engineNode, &pSound->engineNode.pEngine->allocationCallbacks); + if (pSound->pProcessingCache != NULL) { + ma_free(pSound->pProcessingCache, &pSound->engineNode.pEngine->allocationCallbacks); + pSound->pProcessingCache = NULL; + } + /* Once the sound is detached from the group we can guarantee that it won't be referenced by the mixer thread which means it's safe for us to destroy the data source. */ #ifndef MA_NO_RESOURCE_MANAGER if (pSound->ownsDataSource) { @@ -78721,6 +78785,27 @@ MA_API ma_result ma_sound_stop_with_fade_in_milliseconds(ma_sound* pSound, ma_ui return ma_sound_stop_with_fade_in_pcm_frames(pSound, (fadeLengthInMilliseconds * sampleRate) / 1000); } +MA_API void ma_sound_reset_start_time(ma_sound* pSound) +{ + ma_sound_set_start_time_in_pcm_frames(pSound, 0); +} + +MA_API void ma_sound_reset_stop_time(ma_sound* pSound) +{ + ma_sound_set_stop_time_in_pcm_frames(pSound, ~(ma_uint64)0); +} + +MA_API void ma_sound_reset_fade(ma_sound* pSound) +{ + ma_sound_set_fade_in_pcm_frames(pSound, 0, 1, 0); +} + +MA_API void ma_sound_reset_stop_time_and_fade(ma_sound* pSound) +{ + ma_sound_reset_stop_time(pSound); + ma_sound_reset_fade(pSound); +} + MA_API void ma_sound_set_volume(ma_sound* pSound, float volume) { if (pSound == NULL) { @@ -79372,7 +79457,7 @@ MA_API ma_result ma_sound_get_data_format(const ma_sound* pSound, ma_format* pFo } if (pSampleRate != NULL) { - *pSampleRate = pSound->engineNode.resampler.config.sampleRateIn; + *pSampleRate = pSound->engineNode.resampler.sampleRateIn; } if (pChannelMap != NULL) { @@ -82436,7 +82521,6 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory(void* pUserData, int offset, ma_d ma_dr_wav* pWav = (ma_dr_wav*)pUserData; ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - newCursor = pWav->memoryStream.currentReadPos; if (origin == MA_DR_WAV_SEEK_SET) { newCursor = 0; } else if (origin == MA_DR_WAV_SEEK_CUR) { @@ -82490,7 +82574,6 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset ma_dr_wav* pWav = (ma_dr_wav*)pUserData; ma_int64 newCursor; MA_DR_WAV_ASSERT(pWav != NULL); - newCursor = pWav->memoryStreamWrite.currentWritePos; if (origin == MA_DR_WAV_SEEK_SET) { newCursor = 0; } else if (origin == MA_DR_WAV_SEEK_CUR) { @@ -82499,7 +82582,7 @@ MA_PRIVATE ma_bool32 ma_dr_wav__on_seek_memory_write(void* pUserData, int offset newCursor = (ma_int64)pWav->memoryStreamWrite.dataSize; } else { MA_DR_WAV_ASSERT(!"Invalid seek origin"); - return MA_INVALID_ARGS; + return MA_FALSE; } newCursor += offset; if (newCursor < 0) { @@ -83000,7 +83083,7 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][0]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } else { @@ -83022,7 +83105,8 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ pWav->msadpcm.cachedFrames[2] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.cachedFrames[3] = pWav->msadpcm.prevFrames[1][1]; pWav->msadpcm.cachedFrameCount = 2; - if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table) || + pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { return totalFramesRead; } } @@ -83059,6 +83143,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ if (pWav->channels == 1) { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -83083,6 +83170,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } else { ma_int32 newSample0; ma_int32 newSample1; + if (pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[0] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample0 = ((pWav->msadpcm.prevFrames[0][1] * coeff1Table[pWav->msadpcm.predictor[0]]) + (pWav->msadpcm.prevFrames[0][0] * coeff2Table[pWav->msadpcm.predictor[0]])) >> 8; newSample0 += nibble0 * pWav->msadpcm.delta[0]; newSample0 = ma_dr_wav_clamp(newSample0, -32768, 32767); @@ -83092,6 +83182,9 @@ MA_PRIVATE ma_uint64 ma_dr_wav_read_pcm_frames_s16__msadpcm(ma_dr_wav* pWav, ma_ } pWav->msadpcm.prevFrames[0][0] = pWav->msadpcm.prevFrames[0][1]; pWav->msadpcm.prevFrames[0][1] = newSample0; + if (pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff1Table) || pWav->msadpcm.predictor[1] >= ma_dr_wav_countof(coeff2Table)) { + return totalFramesRead; + } newSample1 = ((pWav->msadpcm.prevFrames[1][1] * coeff1Table[pWav->msadpcm.predictor[1]]) + (pWav->msadpcm.prevFrames[1][0] * coeff2Table[pWav->msadpcm.predictor[1]])) >> 8; newSample1 += nibble1 * pWav->msadpcm.delta[1]; newSample1 = ma_dr_wav_clamp(newSample1, -32768, 32767); @@ -84336,6 +84429,10 @@ MA_PRIVATE ma_int16* ma_dr_wav__read_pcm_frames_and_close_s16(ma_dr_wav* pWav, u ma_int16* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int16)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int16); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -84370,6 +84467,10 @@ MA_PRIVATE float* ma_dr_wav__read_pcm_frames_and_close_f32(ma_dr_wav* pWav, unsi float* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(float)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(float); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -84404,6 +84505,10 @@ MA_PRIVATE ma_int32* ma_dr_wav__read_pcm_frames_and_close_s32(ma_dr_wav* pWav, u ma_int32* pSampleData; ma_uint64 framesRead; MA_DR_WAV_ASSERT(pWav != NULL); + if (pWav->channels == 0 || pWav->totalPCMFrameCount > MA_SIZE_MAX / pWav->channels / sizeof(ma_int32)) { + ma_dr_wav_uninit(pWav); + return NULL; + } sampleDataSize = pWav->totalPCMFrameCount * pWav->channels * sizeof(ma_int32); if (sampleDataSize > MA_SIZE_MAX) { ma_dr_wav_uninit(pWav); @@ -85786,7 +85891,7 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint64 r; __asm__ __volatile__ ( - "lzcnt{ %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{q %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return (ma_uint32)r; } @@ -85794,11 +85899,11 @@ static MA_INLINE ma_uint32 ma_dr_flac__clz_lzcnt(ma_dr_flac_cache_t x) { ma_uint32 r; __asm__ __volatile__ ( - "lzcnt{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" + "rep; bsr{l %1, %0| %0, %1}" : "=r"(r) : "r"(x) : "cc" ); return r; } - #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !defined(MA_64BIT) + #elif defined(MA_ARM) && (defined(__ARM_ARCH) && __ARM_ARCH >= 5) && !defined(__ARM_ARCH_6M__) && !(defined(__thumb__) && !defined(__thumb2__)) && !defined(MA_64BIT) { unsigned int r; __asm__ __volatile__ ( @@ -88852,6 +88957,7 @@ static ma_bool32 ma_dr_flac__read_and_decode_metadata(ma_dr_flac_read_proc onRea } } blockSizeRemaining -= metadata.data.picture.pictureDataSize; + (void)blockSizeRemaining; metadata.data.picture.pPictureData = (const ma_uint8*)pPictureData; if (metadata.data.picture.pictureDataOffset != 0 || metadata.data.picture.pPictureData != NULL) { onMeta(pUserDataMD, &metadata); @@ -92276,56 +92382,41 @@ static type* ma_dr_flac__full_read_and_close_ ## extension (ma_dr_flac* pFlac, u { \ type* pSampleData = NULL; \ ma_uint64 totalPCMFrameCount; \ + type buffer[4096]; \ + ma_uint64 pcmFramesRead; \ + size_t sampleDataBufferSize = sizeof(buffer); \ \ MA_DR_FLAC_ASSERT(pFlac != NULL); \ \ - totalPCMFrameCount = pFlac->totalPCMFrameCount; \ + totalPCMFrameCount = 0; \ \ - if (totalPCMFrameCount == 0) { \ - type buffer[4096]; \ - ma_uint64 pcmFramesRead; \ - size_t sampleDataBufferSize = sizeof(buffer); \ + pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pSampleData == NULL) { \ + goto on_error; \ + } \ \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks(sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ + while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ + if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ + type* pNewSampleData; \ + size_t newSampleDataBufferSize; \ \ - while ((pcmFramesRead = (ma_uint64)ma_dr_flac_read_pcm_frames_##extension(pFlac, sizeof(buffer)/sizeof(buffer[0])/pFlac->channels, buffer)) > 0) { \ - if (((totalPCMFrameCount + pcmFramesRead) * pFlac->channels * sizeof(type)) > sampleDataBufferSize) { \ - type* pNewSampleData; \ - size_t newSampleDataBufferSize; \ - \ - newSampleDataBufferSize = sampleDataBufferSize * 2; \ - pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ - if (pNewSampleData == NULL) { \ - ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ - goto on_error; \ - } \ - \ - sampleDataBufferSize = newSampleDataBufferSize; \ - pSampleData = pNewSampleData; \ + newSampleDataBufferSize = sampleDataBufferSize * 2; \ + pNewSampleData = (type*)ma_dr_flac__realloc_from_callbacks(pSampleData, newSampleDataBufferSize, sampleDataBufferSize, &pFlac->allocationCallbacks); \ + if (pNewSampleData == NULL) { \ + ma_dr_flac__free_from_callbacks(pSampleData, &pFlac->allocationCallbacks); \ + goto on_error; \ } \ \ - MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ - totalPCMFrameCount += pcmFramesRead; \ + sampleDataBufferSize = newSampleDataBufferSize; \ + pSampleData = pNewSampleData; \ } \ \ + MA_DR_FLAC_COPY_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), buffer, (size_t)(pcmFramesRead*pFlac->channels*sizeof(type))); \ + totalPCMFrameCount += pcmFramesRead; \ + } \ + \ \ - MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ - } else { \ - ma_uint64 dataSize = totalPCMFrameCount*pFlac->channels*sizeof(type); \ - if (dataSize > (ma_uint64)MA_SIZE_MAX) { \ - goto on_error; \ - } \ - \ - pSampleData = (type*)ma_dr_flac__malloc_from_callbacks((size_t)dataSize, &pFlac->allocationCallbacks); \ - if (pSampleData == NULL) { \ - goto on_error; \ - } \ - \ - totalPCMFrameCount = ma_dr_flac_read_pcm_frames_##extension(pFlac, pFlac->totalPCMFrameCount, pSampleData); \ - } \ + MA_DR_FLAC_ZERO_MEMORY(pSampleData + (totalPCMFrameCount*pFlac->channels), (size_t)(sampleDataBufferSize - totalPCMFrameCount*pFlac->channels*sizeof(type))); \ \ if (sampleRateOut) *sampleRateOut = pFlac->sampleRate; \ if (channelsOut) *channelsOut = pFlac->channels; \ @@ -94685,19 +94776,22 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on ((ma_uint32)ape[25] << 8) | ((ma_uint32)ape[26] << 16) | ((ma_uint32)ape[27] << 24); - streamEndOffset -= 32 + tagSize; - streamLen -= 32 + tagSize; - if (onMeta != NULL) { - if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { - size_t apeTagSize = (size_t)tagSize + 32; - ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); - if (pTagData != NULL) { - if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { - ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + if (32 + tagSize < streamLen) { + streamEndOffset -= 32 + tagSize; + streamLen -= 32 + tagSize; + if (onMeta != NULL) { + if (onSeek(pUserData, streamEndOffset, MA_DR_MP3_SEEK_END)) { + size_t apeTagSize = (size_t)tagSize + 32; + ma_uint8* pTagData = (ma_uint8*)ma_dr_mp3_malloc(apeTagSize, pAllocationCallbacks); + if (pTagData != NULL) { + if (onRead(pUserData, pTagData, apeTagSize) == apeTagSize) { + ma_dr_mp3__on_meta(pMP3, MA_DR_MP3_METADATA_TYPE_APE, pTagData, apeTagSize); + } + ma_dr_mp3_free(pTagData, pAllocationCallbacks); } - ma_dr_mp3_free(pTagData, pAllocationCallbacks); } } + } else { } } } @@ -94785,7 +94879,6 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on { ma_dr_mp3_bs bs; ma_dr_mp3_L3_gr_info grInfo[4]; - const ma_uint8* pTagData = pFirstFrameData; ma_dr_mp3_bs_init(&bs, pFirstFrameData + MA_DR_MP3_HDR_SIZE, firstFrameInfo.frame_bytes - MA_DR_MP3_HDR_SIZE); if (MA_DR_MP3_HDR_IS_CRC(pFirstFrameData)) { ma_dr_mp3_bs_get_bits(&bs, 16); @@ -94793,6 +94886,7 @@ static ma_bool32 ma_dr_mp3_init_internal(ma_dr_mp3* pMP3, ma_dr_mp3_read_proc on if (ma_dr_mp3_L3_read_side_info(&bs, grInfo, pFirstFrameData) >= 0) { ma_bool32 isXing = MA_FALSE; ma_bool32 isInfo = MA_FALSE; + const ma_uint8* pTagData; const ma_uint8* pTagDataBeg; pTagDataBeg = pFirstFrameData + MA_DR_MP3_HDR_SIZE + (bs.pos/8); pTagData = pTagDataBeg; @@ -94892,7 +94986,6 @@ static ma_bool32 ma_dr_mp3__on_seek_memory(void* pUserData, int byteOffset, ma_d ma_dr_mp3* pMP3 = (ma_dr_mp3*)pUserData; ma_int64 newCursor; MA_DR_MP3_ASSERT(pMP3 != NULL); - newCursor = pMP3->memory.currentReadPos; if (origin == MA_DR_MP3_SEEK_SET) { newCursor = 0; } else if (origin == MA_DR_MP3_SEEK_CUR) { @@ -95543,6 +95636,8 @@ static float* ma_dr_mp3__full_read_and_close_f32(ma_dr_mp3* pMP3, ma_dr_mp3_conf pNewFrames = (float*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames; @@ -95594,6 +95689,8 @@ static ma_int16* ma_dr_mp3__full_read_and_close_s16(ma_dr_mp3* pMP3, ma_dr_mp3_c pNewFrames = (ma_int16*)ma_dr_mp3__realloc_from_callbacks(pFrames, (size_t)newFramesBufferSize, (size_t)oldFramesBufferSize, &pMP3->allocationCallbacks); if (pNewFrames == NULL) { ma_dr_mp3__free_from_callbacks(pFrames, &pMP3->allocationCallbacks); + pFrames = NULL; + totalFramesRead = 0; break; } pFrames = pNewFrames;