diff --git a/.github/workflows/build-amd.yml b/.github/workflows/build-amd.yml new file mode 100644 index 0000000000..b6fe8de865 --- /dev/null +++ b/.github/workflows/build-amd.yml @@ -0,0 +1,52 @@ +name: CI (AMD) + +on: + workflow_dispatch: # allows manual triggering + push: + branches: + - master + paths: [ + '.github/workflows/build-amd.yml', + '**/CMakeLists.txt', + '**/.cmake', + '**/*.h', + '**/*.hpp', + '**/*.c', + '**/*.cpp', + '**/*.cu', + '**/*.cuh', + '**/*.comp' + ] + +concurrency: + group: ${{ github.workflow }}-${{ github.head_ref && github.ref || github.run_id }} + cancel-in-progress: true + +jobs: + ggml-ci-x64-amd-vulkan: + runs-on: [self-hosted, Linux, X64, AMD] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Test + id: ggml-ci + run: | + vulkaninfo --summary + GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp + + ggml-ci-x64-amd-rocm: + runs-on: [self-hosted, Linux, X64, AMD] + + steps: + - name: Clone + id: checkout + uses: actions/checkout@v4 + + - name: Test + id: ggml-ci + run: | + amd-smi static + GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp diff --git a/.github/workflows/build-linux-cross.yml b/.github/workflows/build-linux-cross.yml index e0e809ffd1..2b101876c5 100644 --- a/.github/workflows/build-linux-cross.yml +++ b/.github/workflows/build-linux-cross.yml @@ -253,3 +253,47 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + ubuntu-24-riscv64-cpu-spacemit-ime-cross: + runs-on: ubuntu-24.04 + + env: + SPACEMIT_IME_TOOLCHAIN_VERSION: "1.1.2" + SPACEMIT_IME_TOOLCHAIN_PATH: "spacemit-toolchain-linux-glibc-x86_64" + + steps: + - uses: actions/checkout@v4 + + - name: Cache Toolchain + uses: actions/cache@v4 + id: cache-spacemit-ime-cross-toolchain + with: + path: ./${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + key: ${{ runner.os }}-spacemit-ime-toolchain-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }} + + - name: Setup Toolchain + if: steps.cache-spacemit-ime-cross-toolchain.outputs.cache-hit != 'true' + run: | + wget --quiet --no-check-certificate https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v${{ env.SPACEMIT_IME_TOOLCHAIN_VERSION }}.tar.xz -O ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + mkdir -p ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + tar xf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz -C ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} --strip-components=1 + rm -rf ${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }}.tar.xz + + - name: Build + run: | + export RISCV_ROOT_PATH=${PWD}/${{ env.SPACEMIT_IME_TOOLCHAIN_PATH }} + cmake -B build -DLLAMA_CURL=OFF \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_OPENMP=OFF \ + -DLLAMA_BUILD_EXAMPLES=ON \ + -DLLAMA_BUILD_TOOLS=ON \ + -DLLAMA_BUILD_TESTS=OFF \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake + + cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build-riscv-native.yml b/.github/workflows/build-riscv-native.yml index acad316602..a3a0b0d663 100644 --- a/.github/workflows/build-riscv-native.yml +++ b/.github/workflows/build-riscv-native.yml @@ -58,3 +58,63 @@ jobs: -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH cmake --build build --config Release -j $(nproc) + + # debian-13-riscv64-spacemit-ime-native: # Bianbu 2.2 + # runs-on: [self-hosted, RISCV64] + + # steps: + # - name: Install prerequisites + # run: | + # sudo apt-get update || true + # sudo apt-get install -y libatomic1 + # - uses: actions/checkout@v4 + # - name: Setup Riscv + # run: | + # sudo apt-get update || true + # sudo apt-get install -y --no-install-recommends \ + # build-essential \ + # gcc-14-riscv64-linux-gnu \ + # g++-14-riscv64-linux-gnu \ + # ccache \ + # cmake + # sudo apt-get upgrade binutils -y + + # - name: Setup ccache + # run: | + # mkdir -p $HOME/.ccache + # ccache -M 5G -d $HOME/.ccache + # export CCACHE_LOGFILE=/home/runneruser/ccache_debug/ccache.log + # export CCACHE_DEBUGDIR="/home/runneruser/ccache_debug" + # echo "$GITHUB_WORKSPACE" + # echo "CCACHE_LOGFILE=$CCACHE_LOGFILE" >> $GITHUB_ENV + # echo "CCACHE_DEBUGDIR=$CCACHE_DEBUGDIR" >> $GITHUB_ENV + # echo "CCACHE_BASEDIR=$GITHUB_WORKSPACE" >> $GITHUB_ENV + # echo "CCACHE_DIR=$HOME/.ccache" >> $GITHUB_ENV + + # - name: Build + # run: | + # cmake -B build \ + # -DLLAMA_CURL=OFF \ + # -DCMAKE_BUILD_TYPE=Release \ + # -DGGML_OPENMP=OFF \ + # -DLLAMA_BUILD_EXAMPLES=ON \ + # -DLLAMA_BUILD_TOOLS=ON \ + # -DLLAMA_BUILD_TESTS=OFF \ + # -DCMAKE_SYSTEM_NAME=Linux \ + # -DCMAKE_SYSTEM_PROCESSOR=riscv64 \ + # -DCMAKE_C_COMPILER=riscv64-linux-gnu-gcc-14 \ + # -DCMAKE_CXX_COMPILER=riscv64-linux-gnu-g++-14 \ + # -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + # -DCMAKE_POSITION_INDEPENDENT_CODE=ON \ + # -DCMAKE_FIND_ROOT_PATH=/usr/lib/riscv64-linux-gnu \ + # -DCMAKE_FIND_ROOT_PATH_MODE_PROGRAM=NEVER \ + # -DCMAKE_FIND_ROOT_PATH_MODE_LIBRARY=ONLY \ + # -DCMAKE_FIND_ROOT_PATH_MODE_INCLUDE=BOTH \ + # -DGGML_RVV=ON \ + # -DGGML_RV_ZFH=ON \ + # -DGGML_RV_ZICBOP=ON \ + # -DGGML_CPU_RISCV64_SPACEMIT=ON \ + # -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 + + # cmake --build build --config Release -j $(nproc) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 424b4ba786..410552813a 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -207,7 +207,7 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - name: Build Dependencies @@ -1222,11 +1222,12 @@ jobs: - name: Clone uses: actions/checkout@v4 - - name: ccache - uses: ggml-org/ccache-action@v1.2.16 - with: - key: android-build - evict-old-files: 1d + # Disabled due to size (400MB) and always 0 cache hits + # - name: ccache + # uses: ggml-org/ccache-action@v1.2.16 + # with: + # key: android-build + # evict-old-files: 1d - name: Set up JDK uses: actions/setup-java@v3 @@ -1461,34 +1462,6 @@ jobs: run: | bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp -# ggml-ci-x64-amd-vulkan: -# runs-on: [self-hosted, Linux, X64, AMD] -# -# steps: -# - name: Clone -# id: checkout -# uses: actions/checkout@v4 -# -# - name: Test -# id: ggml-ci -# run: | -# vulkaninfo --summary -# GG_BUILD_VULKAN=1 bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp -# -# ggml-ci-x64-amd-rocm: -# runs-on: [self-hosted, Linux, X64, AMD] -# -# steps: -# - name: Clone -# id: checkout -# uses: actions/checkout@v4 -# -# - name: Test -# id: ggml-ci -# run: | -# amd-smi static -# GG_BUILD_ROCM=1 GG_BUILD_AMDGPU_TARGETS="gfx1101" bash ./ci/run.sh ~/results/llama.cpp /mnt/llama.cpp - ggml-ci-mac-metal: runs-on: [self-hosted, macOS, ARM64] diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 446c666b90..f73a2bc9f4 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -89,12 +89,15 @@ jobs: TYPE="-${{ matrix.config.tag }}" fi PREFIX="ghcr.io/${REPO_OWNER}/${REPO_NAME}:" + CACHETAGS="${PREFIX}buildcache${TYPE}" FULLTAGS="${PREFIX}full${TYPE},${PREFIX}full${TYPE}-${{ steps.srctag.outputs.name }}" LIGHTTAGS="${PREFIX}light${TYPE},${PREFIX}light${TYPE}-${{ steps.srctag.outputs.name }}" SERVERTAGS="${PREFIX}server${TYPE},${PREFIX}server${TYPE}-${{ steps.srctag.outputs.name }}" + echo "cache_output_tags=$CACHETAGS" >> $GITHUB_OUTPUT echo "full_output_tags=$FULLTAGS" >> $GITHUB_OUTPUT echo "light_output_tags=$LIGHTTAGS" >> $GITHUB_OUTPUT echo "server_output_tags=$SERVERTAGS" >> $GITHUB_OUTPUT + echo "cache_output_tags=$CACHETAGS" # print out for debugging echo "full_output_tags=$FULLTAGS" # print out for debugging echo "light_output_tags=$LIGHTTAGS" # print out for debugging echo "server_output_tags=$SERVERTAGS" # print out for debugging @@ -131,11 +134,14 @@ jobs: target: full provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Light Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.light == true }} @@ -150,11 +156,14 @@ jobs: target: light provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max - name: Build and push Server Docker image (tagged + versioned) if: ${{ (github.event_name == 'push' || github.event_name == 'schedule' || github.event_name == 'workflow_dispatch') && matrix.config.server == true }} @@ -169,11 +178,14 @@ jobs: target: server provenance: false # using github experimental cache - cache-from: type=gha - cache-to: type=gha,mode=max + #cache-from: type=gha + #cache-to: type=gha,mode=max # return to this if the experimental github cache is having issues #cache-to: type=local,dest=/tmp/.buildx-cache #cache-from: type=local,src=/tmp/.buildx-cache + # using registry cache (no storage limit) + cache-from: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }} + cache-to: type=registry,ref=${{ steps.tag.outputs.cache_output_tags }},mode=max create_tag: name: Create and push git tag diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index f461456edf..f4eae5da11 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -150,7 +150,7 @@ jobs: - name: ccache uses: ggml-org/ccache-action@v1.2.16 with: - key: ubuntu-cpu-cmake + key: ubuntu-cpu-cmake-${{ matrix.build }} evict-old-files: 1d - name: Dependencies diff --git a/CODEOWNERS b/CODEOWNERS index 8f19c3f7f7..89b84ce850 100644 --- a/CODEOWNERS +++ b/CODEOWNERS @@ -14,6 +14,7 @@ /common/build-info.* @ggerganov /common/common.* @ggerganov /common/console.* @ggerganov +/common/http.* @angt /common/llguidance.* @ggerganov /common/log.* @ggerganov /common/sampling.* @ggerganov @@ -50,6 +51,7 @@ /ggml/src/ggml-blas/ @slaren /ggml/src/ggml-common.h @ggerganov @slaren /ggml/src/ggml-cpu/ @ggerganov @slaren +/ggml/src/ggml-cpu/spacemit/ @alex-spacemit /ggml/src/ggml-cuda/common.cuh @slaren /ggml/src/ggml-cuda/fattn* @JohannesGaessler /ggml/src/ggml-cuda/ggml-cuda.cu @slaren @@ -59,6 +61,7 @@ /ggml/src/ggml-cuda/mmvq.* @JohannesGaessler /ggml/src/ggml-impl.h @ggerganov @slaren /ggml/src/ggml-metal/ @ggerganov +/ggml/src/ggml-opencl/ @lhez @max-krasnyansky /ggml/src/ggml-opt.cpp @JohannesGaessler /ggml/src/ggml-quants.* @ggerganov /ggml/src/ggml-rpc/ @rgerganov diff --git a/ci/run.sh b/ci/run.sh index 68cbfdf2f5..b0af51723b 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -114,6 +114,7 @@ if [ ! -z ${GG_BUILD_NO_SVE} ]; then # arm 9 and newer enables sve by default, adjust these flags depending on the cpu used CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_NATIVE=OFF -DGGML_CPU_ARM_ARCH=armv8.5-a+fp16+i8mm" fi + ## helpers # download a file if it does not exist or if it is outdated diff --git a/cmake/riscv64-spacemit-linux-gnu-gcc.cmake b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake new file mode 100644 index 0000000000..08fdbf5063 --- /dev/null +++ b/cmake/riscv64-spacemit-linux-gnu-gcc.cmake @@ -0,0 +1,29 @@ +set(CMAKE_SYSTEM_NAME Linux) +set(CMAKE_SYSTEM_PROCESSOR riscv64) +set(CMAKE_SYSTEM_VERSION 1) + +if (CMAKE_HOST_SYSTEM_PROCESSOR MATCHES "^(riscv)") + message(STATUS "HOST SYSTEM ${CMAKE_HOST_SYSTEM_PROCESSOR}") +else() + set(GNU_MACHINE riscv64-unknown-linux-gnu CACHE STRING "GNU compiler triple") + if (DEFINED ENV{RISCV_ROOT_PATH}) + file(TO_CMAKE_PATH $ENV{RISCV_ROOT_PATH} RISCV_ROOT_PATH) + else() + message(FATAL_ERROR "RISCV_ROOT_PATH env must be defined") + endif() + + set(RISCV_ROOT_PATH ${RISCV_ROOT_PATH} CACHE STRING "root path to riscv toolchain") + set(CMAKE_C_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-gcc) + set(CMAKE_CXX_COMPILER ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-g++) + set(CMAKE_STRIP ${RISCV_ROOT_PATH}/bin/riscv64-unknown-linux-gnu-strip) + set(CMAKE_FIND_ROOT_PATH "${RISCV_ROOT_PATH}/riscv64-unknown-linux-gnu") + set(CMAKE_SYSROOT "${RISCV_ROOT_PATH}/sysroot") +endif() + +set(CMAKE_FIND_ROOT_PATH_MODE_PROGRAM NEVER) +set(CMAKE_FIND_ROOT_PATH_MODE_LIBRARY ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_INCLUDE ONLY) +set(CMAKE_FIND_ROOT_PATH_MODE_PACKAGE ONLY) +set(CMAKE_C_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CMAKE_C_FLAGS}") +set(CMAKE_CXX_FLAGS "-march=rv64gcv_zfh_zba_zicbop -mabi=lp64d ${CXX_FLAGS}") +set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -latomic") diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 8ab3d44510..fe290bf8fd 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -56,6 +56,7 @@ add_library(${TARGET} STATIC common.h console.cpp console.h + http.h json-partial.cpp json-partial.h json-schema-to-grammar.cpp diff --git a/common/arg.cpp b/common/arg.cpp index f6a775fc4a..cbca8b5ac5 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -32,13 +32,11 @@ #include #include -//#define LLAMA_USE_CURL - #if defined(LLAMA_USE_CURL) #include #include #else -#include +#include "http.h" #endif #ifdef __linux__ @@ -54,6 +52,13 @@ #endif #define LLAMA_MAX_URL_LENGTH 2084 // Maximum URL Length in Chrome: 2083 +// isatty +#if defined(_WIN32) +#include +#else +#include +#endif + using json = nlohmann::ordered_json; std::initializer_list mmproj_examples = { @@ -100,6 +105,14 @@ static void write_file(const std::string & fname, const std::string & content) { } } +static bool is_output_a_tty() { +#if defined(_WIN32) + return _isatty(_fileno(stdout)); +#else + return isatty(1); +#endif +} + common_arg & common_arg::set_examples(std::initializer_list examples) { this->examples = std::move(examples); return *this; @@ -217,12 +230,55 @@ struct common_hf_file_res { std::string mmprojFile; }; -#ifdef LLAMA_USE_CURL - -bool common_has_curl() { - return true; +static void write_etag(const std::string & path, const std::string & etag) { + const std::string etag_path = path + ".etag"; + write_file(etag_path, etag); + LOG_DBG("%s: file etag saved: %s\n", __func__, etag_path.c_str()); } +static std::string read_etag(const std::string & path) { + std::string none; + const std::string etag_path = path + ".etag"; + + if (std::filesystem::exists(etag_path)) { + std::ifstream etag_in(etag_path); + if (!etag_in) { + LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); + return none; + } + std::string etag; + std::getline(etag_in, etag); + return etag; + } + + // no etag file, but maybe there is an old .json + // remove this code later + const std::string metadata_path = path + ".json"; + + if (std::filesystem::exists(metadata_path)) { + std::ifstream metadata_in(metadata_path); + try { + nlohmann::json metadata_json; + metadata_in >> metadata_json; + LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), + metadata_json.dump().c_str()); + if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { + std::string etag = metadata_json.at("etag"); + write_etag(path, etag); + if (!std::filesystem::remove(metadata_path)) { + LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); + } + return etag; + } + } catch (const nlohmann::json::exception & e) { + LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); + } + } + return none; +} + +#ifdef LLAMA_USE_CURL + // // CURL utils // @@ -373,36 +429,15 @@ static bool common_download_head(CURL * curl, static bool common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token) { - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { - nlohmann::json metadata; // TODO @ngxson : get rid of this json, use regex instead - std::string etag; - std::string last_modified; + std::string etag; // Check if the file already exists locally const auto file_exists = std::filesystem::exists(path); if (file_exists) { - // Try and read the JSON metadata file (note: stream autoclosed upon exiting this block). - std::ifstream metadata_in(metadata_path); - if (metadata_in.good()) { - try { - metadata_in >> metadata; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata.dump().c_str()); - if (metadata.contains("etag") && metadata.at("etag").is_string()) { - etag = metadata.at("etag"); - } - if (metadata.contains("lastModified") && metadata.at("lastModified").is_string()) { - last_modified = metadata.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } - } - // if we cannot open the metadata file, we assume that the downloaded file is not valid (etag and last-modified are left empty, so we will download it again) + etag = read_etag(path); } else { LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } @@ -440,11 +475,6 @@ static bool common_download_file_single_online(const std::string & url, headers.etag.c_str()); should_download = true; should_download_from_scratch = true; - } else if (!last_modified.empty() && last_modified != headers.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, - last_modified.c_str(), headers.last_modified.c_str()); - should_download = true; - should_download_from_scratch = true; } } @@ -475,15 +505,9 @@ static bool common_download_file_single_online(const std::string & url, } } } - - // Write the updated JSON metadata file. - metadata.update({ - { "url", url }, - { "etag", headers.etag }, - { "lastModified", headers.last_modified } - }); - write_file(metadata_path, metadata.dump(4)); - LOG_DBG("%s: file metadata saved: %s\n", __func__, metadata_path.c_str()); + if (head_request_ok) { + write_etag(path, headers.etag); + } // start the download LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", @@ -570,82 +594,11 @@ std::pair> common_remote_get_content(const std::string & #else -bool common_has_curl() { - return false; -} - -struct common_url { - std::string scheme; - std::string user; - std::string password; - std::string host; - std::string path; -}; - -static common_url parse_url(const std::string & url) { - common_url parts; - auto scheme_end = url.find("://"); - - if (scheme_end == std::string::npos) { - throw std::runtime_error("invalid URL: no scheme"); - } - parts.scheme = url.substr(0, scheme_end); - - if (parts.scheme != "http" && parts.scheme != "https") { - throw std::runtime_error("unsupported URL scheme: " + parts.scheme); +static void print_progress(size_t current, size_t total) { + if (!is_output_a_tty()) { + return; } - auto rest = url.substr(scheme_end + 3); - auto at_pos = rest.find('@'); - - if (at_pos != std::string::npos) { - auto auth = rest.substr(0, at_pos); - auto colon_pos = auth.find(':'); - if (colon_pos != std::string::npos) { - parts.user = auth.substr(0, colon_pos); - parts.password = auth.substr(colon_pos + 1); - } else { - parts.user = auth; - } - rest = rest.substr(at_pos + 1); - } - - auto slash_pos = rest.find('/'); - - if (slash_pos != std::string::npos) { - parts.host = rest.substr(0, slash_pos); - parts.path = rest.substr(slash_pos); - } else { - parts.host = rest; - parts.path = "/"; - } - return parts; -} - -static std::pair http_client(const std::string & url) { - common_url parts = parse_url(url); - - if (parts.host.empty()) { - throw std::runtime_error("error: invalid URL format"); - } - - if (!parts.user.empty()) { - throw std::runtime_error("error: user:password@ not supported yet"); // TODO - } - - httplib::Client cli(parts.scheme + "://" + parts.host); - cli.set_follow_location(true); - - // TODO cert - - return { std::move(cli), std::move(parts) }; -} - -static std::string show_masked_url(const common_url & parts) { - return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path; -} - -static void print_progress(size_t current, size_t total) { // TODO isatty if (!total) { return; } @@ -664,51 +617,6 @@ static void print_progress(size_t current, size_t total) { // TODO isatty std::cout.flush(); } -struct common_file_metadata { - std::string etag; - std::string last_modified; -}; - -static std::optional read_metadata(const std::string & path) { - if (!std::filesystem::exists(path)) { - return std::nullopt; - } - - nlohmann::json metadata_json; - common_file_metadata metadata; - - std::ifstream metadata_in(path); - try { - metadata_in >> metadata_json; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, path.c_str(), - metadata_json.dump().c_str()); - if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { - metadata.etag = metadata_json.at("etag"); - } - if (metadata_json.contains("lastModified") && metadata_json.at("lastModified").is_string()) { - metadata.last_modified = metadata_json.at("lastModified"); - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, path.c_str(), e.what()); - return std::nullopt; - } - - return metadata; -} - -static void write_metadata(const std::string & path, - const std::string & url, - const common_file_metadata & metadata) { - nlohmann::json metadata_json = { - { "url", url }, - { "etag", metadata.etag }, - { "lastModified", metadata.last_modified } - }; - - write_file(path, metadata_json.dump(4)); - LOG_DBG("%s: file metadata saved: %s\n", __func__, path.c_str()); -} - static bool common_pull_file(httplib::Client & cli, const std::string & resolve_path, const std::string & path_tmp, @@ -775,12 +683,10 @@ static bool common_pull_file(httplib::Client & cli, static bool common_download_file_single_online(const std::string & url, const std::string & path, const std::string & bearer_token) { - // If the file exists, check its JSON metadata companion file. - std::string metadata_path = path + ".json"; static const int max_attempts = 3; static const int retry_delay_seconds = 2; - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers default_headers = {{"User-Agent", "llama-cpp"}}; if (!bearer_token.empty()) { @@ -788,12 +694,11 @@ static bool common_download_file_single_online(const std::string & url, } cli.set_default_headers(default_headers); - common_file_metadata last; const bool file_exists = std::filesystem::exists(path); + + std::string last_etag; if (file_exists) { - if (auto opt = read_metadata(metadata_path)) { - last = *opt; - } + last_etag = read_etag(path); } else { LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } @@ -809,14 +714,9 @@ static bool common_download_file_single_online(const std::string & url, } } - common_file_metadata current; - if (head_ok) { - if (head->has_header("ETag")) { - current.etag = head->get_header_value("ETag"); - } - if (head->has_header("Last-Modified")) { - current.last_modified = head->get_header_value("Last-Modified"); - } + std::string etag; + if (head_ok && head->has_header("ETag")) { + etag = head->get_header_value("ETag"); } size_t total_size = 0; @@ -834,16 +734,10 @@ static bool common_download_file_single_online(const std::string & url, } bool should_download_from_scratch = false; - if (head_ok) { - if (!last.etag.empty() && last.etag != current.etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, - last.etag.c_str(), current.etag.c_str()); - should_download_from_scratch = true; - } else if (!last.last_modified.empty() && last.last_modified != current.last_modified) { - LOG_WRN("%s: Last-Modified header is different (%s != %s): triggering a new download\n", __func__, - last.last_modified.c_str(), current.last_modified.c_str()); - should_download_from_scratch = true; - } + if (!last_etag.empty() && !etag.empty() && last_etag != etag) { + LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, + last_etag.c_str(), etag.c_str()); + should_download_from_scratch = true; } if (file_exists) { @@ -871,9 +765,8 @@ static bool common_download_file_single_online(const std::string & url, } // start the download - LOG_INF("%s: trying to download model from %s to %s (server_etag:%s, server_last_modified:%s)...\n", - __func__, show_masked_url(parts).c_str(), path_temporary.c_str(), - current.etag.c_str(), current.last_modified.c_str()); + LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", + __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); if (!was_pull_successful) { if (i + 1 < max_attempts) { @@ -883,7 +776,6 @@ static bool common_download_file_single_online(const std::string & url, } else { LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); } - continue; } @@ -891,7 +783,9 @@ static bool common_download_file_single_online(const std::string & url, LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); return false; } - write_metadata(metadata_path, url, current); + if (!etag.empty()) { + write_etag(path, etag); + } break; } @@ -900,7 +794,7 @@ static bool common_download_file_single_online(const std::string & url, std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params) { - auto [cli, parts] = http_client(url); + auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; for (const auto & header : params.headers) { diff --git a/common/arg.h b/common/arg.h index 70bea100fd..77997c4ef3 100644 --- a/common/arg.h +++ b/common/arg.h @@ -78,7 +78,6 @@ bool common_params_parse(int argc, char ** argv, common_params & params, llama_e // function to be used by test-arg-parser common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); -bool common_has_curl(); struct common_remote_params { std::vector headers; diff --git a/common/http.h b/common/http.h new file mode 100644 index 0000000000..8e29787dcc --- /dev/null +++ b/common/http.h @@ -0,0 +1,73 @@ +#pragma once + +#include + +struct common_http_url { + std::string scheme; + std::string user; + std::string password; + std::string host; + std::string path; +}; + +static common_http_url common_http_parse_url(const std::string & url) { + common_http_url parts; + auto scheme_end = url.find("://"); + + if (scheme_end == std::string::npos) { + throw std::runtime_error("invalid URL: no scheme"); + } + parts.scheme = url.substr(0, scheme_end); + + if (parts.scheme != "http" && parts.scheme != "https") { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + auto rest = url.substr(scheme_end + 3); + auto at_pos = rest.find('@'); + + if (at_pos != std::string::npos) { + auto auth = rest.substr(0, at_pos); + auto colon_pos = auth.find(':'); + if (colon_pos != std::string::npos) { + parts.user = auth.substr(0, colon_pos); + parts.password = auth.substr(colon_pos + 1); + } else { + parts.user = auth; + } + rest = rest.substr(at_pos + 1); + } + + auto slash_pos = rest.find('/'); + + if (slash_pos != std::string::npos) { + parts.host = rest.substr(0, slash_pos); + parts.path = rest.substr(slash_pos); + } else { + parts.host = rest; + parts.path = "/"; + } + return parts; +} + +static std::pair common_http_client(const std::string & url) { + common_http_url parts = common_http_parse_url(url); + + if (parts.host.empty()) { + throw std::runtime_error("error: invalid URL format"); + } + + httplib::Client cli(parts.scheme + "://" + parts.host); + + if (!parts.user.empty()) { + cli.set_basic_auth(parts.user, parts.password); + } + + cli.set_follow_location(true); + + return { std::move(cli), std::move(parts) }; +} + +static std::string common_http_show_masked_url(const common_http_url & parts) { + return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path; +} diff --git a/docs/build-riscv64-spacemit.md b/docs/build-riscv64-spacemit.md new file mode 100644 index 0000000000..eaa6532546 --- /dev/null +++ b/docs/build-riscv64-spacemit.md @@ -0,0 +1,89 @@ +> [!IMPORTANT] +> This build documentation is specific only to RISC-V SpacemiT SOCs. + +## Build llama.cpp locally (for riscv64) + +1. Prepare Toolchain For RISCV +~~~ +wget https://archive.spacemit.com/toolchain/spacemit-toolchain-linux-glibc-x86_64-v1.1.2.tar.xz +~~~ + +2. Build +Below is the build script: it requires utilizing RISC-V vector instructions for acceleration. Ensure the `GGML_CPU_RISCV64_SPACEMIT` compilation option is enabled. The currently supported optimization version is `RISCV64_SPACEMIT_IME1`, corresponding to the `RISCV64_SPACEMIT_IME_SPEC` compilation option. Compiler configurations are defined in the `riscv64-spacemit-linux-gnu-gcc.cmake` file. Please ensure you have installed the RISC-V compiler and set the environment variable via `export RISCV_ROOT_PATH={your_compiler_path}`. +```bash + +cmake -B build \ + -DCMAKE_BUILD_TYPE=Release \ + -DGGML_CPU_RISCV64_SPACEMIT=ON \ + -DLLAMA_CURL=OFF \ + -DGGML_RVV=ON \ + -DGGML_RV_ZFH=ON \ + -DGGML_RV_ZICBOP=ON \ + -DRISCV64_SPACEMIT_IME_SPEC=RISCV64_SPACEMIT_IME1 \ + -DCMAKE_TOOLCHAIN_FILE=${PWD}/cmake/riscv64-spacemit-linux-gnu-gcc.cmake \ + -DCMAKE_INSTALL_PREFIX=build/installed + +cmake --build build --parallel $(nproc) --config Release + +pushd build +make install +popd +``` + +## Simulation +You can use QEMU to perform emulation on non-RISC-V architectures. + +1. Download QEMU +~~~ +wget https://archive.spacemit.com/spacemit-ai/qemu/jdsk-qemu-v0.0.14.tar.gz +~~~ + +2. Run Simulation +After build your llama.cpp, you can run the executable file via QEMU for simulation, for example: +~~~ +export QEMU_ROOT_PATH={your QEMU file path} +export RISCV_ROOT_PATH_IME1={your RISC-V compiler path} + +${QEMU_ROOT_PATH}/bin/qemu-riscv64 -L ${RISCV_ROOT_PATH_IME1}/sysroot -cpu max,vlen=256,elen=64,vext_spec=v1.0 ${PWD}/build/bin/llama-cli -m ${PWD}/models/Qwen2.5-0.5B-Instruct-Q4_0.gguf -t 1 +~~~ +## Performance +#### Quantization Support For Matrix +~~~ +model name : Spacemit(R) X60 +isa : rv64imafdcv_zicbom_zicboz_zicntr_zicond_zicsr_zifencei_zihintpause_zihpm_zfh_zfhmin_zca_zcd_zba_zbb_zbc_zbs_zkt_zve32f_zve32x_zve64d_zve64f_zve64x_zvfh_zvfhmin_zvkt_sscofpmf_sstc_svinval_svnapot_svpbmt +mmu : sv39 +uarch : spacemit,x60 +mvendorid : 0x710 +marchid : 0x8000000058000001 +~~~ + +Q4_0 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | pp512|64.12 ± 0.26| +Qwen2.5 0.5B |403.20 MiB|630.17 M| cpu | 4 | tg128|10.03 ± 0.01| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | pp512|24.16 ± 0.02| +Qwen2.5 1.5B |1011.16 MiB| 1.78 B | cpu | 4 | tg128|3.83 ± 0.06| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | pp512|12.08 ± 0.02| +Qwen2.5 3B | 1.86 GiB | 3.40 B | cpu | 4 | tg128|2.23 ± 0.02| + +Q4_1 +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | pp512|62.07 ± 0.12| +Qwen2.5 0.5B |351.50 MiB|494.03 M| cpu | 4 | tg128|9.91 ± 0.01| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | pp512|22.95 ± 0.25| +Qwen2.5 1.5B |964.06 MiB| 1.54 B | cpu | 4 | tg128|4.01 ± 0.15| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | pp512|11.55 ± 0.16| +Qwen2.5 3B | 1.85 GiB | 3.09 B | cpu | 4 | tg128|2.25 ± 0.04| + + +Q4_K +| Model | Size | Params | backend | threads | test | t/s | +| -----------| -------- | ------ | ------- | ------- | ---- |------| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | pp512|9.29 ± 0.05| +Qwen2.5 0.5B |462.96 MiB|630.17 M| cpu | 4 | tg128|5.67 ± 0.04| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | pp512|10.38 ± 0.10| +Qwen2.5 1.5B | 1.04 GiB | 1.78 B | cpu | 4 | tg128|3.17 ± 0.08| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | pp512|4.23 ± 0.04| +Qwen2.5 3B | 1.95 GiB | 3.40 B | cpu | 4 | tg128|1.73 ± 0.00| diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 1a0fdb676c..56420587a9 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,8 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 0) -set(GGML_VERSION_DEV "-dev") # "-dev" for development, "" for releases +set(GGML_VERSION_PATCH 4) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) @@ -26,8 +25,8 @@ if(GIT_EXE) ) endif() -# Build the version string with optional -dev suffix and dirty flag -set(GGML_VERSION "${GGML_VERSION_BASE}${GGML_VERSION_DEV}") +# Build the version string with optional dirty flag +set(GGML_VERSION "${GGML_VERSION_BASE}") if(GGML_GIT_DIRTY AND NOT GGML_GIT_DIRTY EQUAL 0) set(GGML_VERSION "${GGML_VERSION}-dirty") endif() diff --git a/ggml/include/ggml.h b/ggml/include/ggml.h index 36b23dc6d0..5028a9cebf 100644 --- a/ggml/include/ggml.h +++ b/ggml/include/ggml.h @@ -237,6 +237,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 7002cb07e0..136afec748 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -135,6 +135,10 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { return p; } +static const char * dl_error() { + return ""; +} + #else using dl_handle = void; @@ -155,6 +159,11 @@ static void * dl_get_sym(dl_handle * handle, const char * name) { return dlsym(handle, name); } +static const char * dl_error() { + const char *rslt = dlerror(); + return rslt != nullptr ? rslt : ""; +} + #endif using dl_handle_ptr = std::unique_ptr; @@ -240,7 +249,7 @@ struct ggml_backend_registry { dl_handle_ptr handle { dl_load_library(path) }; if (!handle) { if (!silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(path).c_str()); + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(path).c_str(), dl_error()); } return nullptr; } @@ -530,7 +539,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, if (filename.native().find(file_prefix) == 0 && ext == file_extension) { dl_handle_ptr handle { dl_load_library(entry) }; if (!handle && !silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, path_str(entry.path()).c_str()); + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, path_str(entry.path()).c_str(), dl_error()); } if (handle) { auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); diff --git a/ggml/src/ggml-blas/CMakeLists.txt b/ggml/src/ggml-blas/CMakeLists.txt index 76064c3fd1..60ce4b1e02 100644 --- a/ggml/src/ggml-blas/CMakeLists.txt +++ b/ggml/src/ggml-blas/CMakeLists.txt @@ -74,7 +74,7 @@ if (BLAS_FOUND) target_compile_options(ggml-blas PRIVATE ${BLAS_LINKER_FLAGS}) - if (${BLAS_INCLUDE_DIRS} MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) + if ("${BLAS_INCLUDE_DIRS}" MATCHES "mkl" AND (${GGML_BLAS_VENDOR} MATCHES "Generic" OR ${GGML_BLAS_VENDOR} MATCHES "Intel")) add_compile_definitions(GGML_BLAS_USE_MKL) endif() diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 3699057507..42041b717a 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -439,6 +439,15 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/repack.cpp ) + if (GGML_CPU_RISCV64_SPACEMIT) + target_compile_definitions(${GGML_CPU_NAME} PRIVATE GGML_USE_CPU_RISCV64_SPACEMIT ${RISCV64_SPACEMIT_IME_SPEC}) + list(APPEND GGML_CPU_SOURCES + ggml-cpu/spacemit/ime.cpp + ggml-cpu/spacemit/ime.h + ggml-cpu/spacemit/ime1_kernels.cpp + ggml-cpu/spacemit/ime_kernels.h + ) + endif() set(MARCH_STR "rv64gc") if (GGML_RV_ZFH) string(APPEND MARCH_STR "_zfh") @@ -504,9 +513,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name) # Fetch KleidiAI sources: include(FetchContent) - set(KLEIDIAI_COMMIT_TAG "v1.13.0") + set(KLEIDIAI_COMMIT_TAG "v1.14.0") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") - set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1") + set(KLEIDIAI_ARCHIVE_MD5 "45e110675d93f99f82c23a1afcca76bc") if (POLICY CMP0135) cmake_policy(SET CMP0135 NEW) @@ -583,6 +592,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name) ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa_asm.S ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c ${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S) diff --git a/ggml/src/ggml-cpu/ggml-cpu.cpp b/ggml/src/ggml-cpu/ggml-cpu.cpp index 81a314e4d6..3191faaa4c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.cpp +++ b/ggml/src/ggml-cpu/ggml-cpu.cpp @@ -18,6 +18,10 @@ # include "kleidiai/kleidiai.h" #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT +# include "spacemit/ime.h" +#endif + #if defined(_WIN32) # define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX @@ -45,6 +49,12 @@ std::vector & ggml_backend_cpu_get_extra_buffer_type } #endif +#ifdef GGML_USE_CPU_RISCV64_SPACEMIT + if (ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + bufts.push_back(ggml_backend_cpu_riscv64_spacemit_buffer_type()); + } +#endif + #ifdef GGML_USE_CPU_KLEIDIAI if (ggml_backend_cpu_kleidiai_buffer_type()) { bufts.push_back(ggml_backend_cpu_kleidiai_buffer_type()); diff --git a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp index 8694ee15d3..44691e5dfd 100644 --- a/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp +++ b/ggml/src/ggml-cpu/kleidiai/kleidiai.cpp @@ -87,15 +87,38 @@ static inline int64_t ggml_ne(const ggml_tensor * tensor, int dim) { return tensor->ne[dim]; } +template +constexpr bool variant_any_invocable_impl(std::index_sequence) { + using V = std::remove_reference_t; + return (std::is_invocable_r_v< + Ret, + std::variant_alternative_t, + Args...> || ...); +} + +template +constexpr bool variant_any_invocable_v = + variant_any_invocable_impl( + std::make_index_sequence< + std::variant_size_v>>{}); + template -static Ret variant_call(const Variant & var, Args&&... args) { - return std::visit([&](auto&& func) -> Ret { - if constexpr (std::is_invocable_r_v) { - return func(std::forward(args)...); - } else { - throw std::runtime_error("Invalid function type in variant_call"); - } - }, var); +static inline Ret variant_call(Variant && var, Args&&... args) { + static_assert(variant_any_invocable_v, Ret, Args...>, + "No alternative in Variant is invocable with the provided arguments and return type."); + + return std::visit( + [&](auto && f) -> Ret { + using F = std::decay_t; + if constexpr (std::is_invocable_r_v) { + return std::invoke(std::forward(f), std::forward(args)...); + } else { + GGML_ABORT("Invalid function type in variant_call"); + GGML_UNREACHABLE(); + } + }, + std::forward(var) + ); } namespace ggml::cpu::kleidiai { @@ -138,7 +161,10 @@ class tensor_traits : public ggml::cpu::tensor_traits { if (kernels->rhs_type == GGML_TYPE_Q4_0) { size = variant_call(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr); } else if (kernels->rhs_type == GGML_TYPE_F16) { - size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr) + + const int64_t lhs_batch_size0 = op->src[1]->ne[2]; + const int64_t rhs_batch_size0 = op->src[0]->ne[2]; + const int64_t r = lhs_batch_size0 / rhs_batch_size0; + size = variant_call(lhs_info->packed_size, m * r, k, mr, kr, sr) + variant_call(kernels->rhs_info.packed_size, n, k) + k * n * sizeof(float) + n * sizeof(float); } else { @@ -148,7 +174,6 @@ class tensor_traits : public ggml::cpu::tensor_traits { return true; } - bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * dst) override { if (dst->op == GGML_OP_MUL_MAT) { if (dst->src[0]->type == GGML_TYPE_Q4_0) { @@ -165,8 +190,6 @@ class tensor_traits : public ggml::cpu::tensor_traits { } bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) { - static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; - const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src1 = dst->src[1]; @@ -175,7 +198,7 @@ class tensor_traits : public ggml::cpu::tensor_traits { ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); GGML_ASSERT(kernels); - bool is_gemv = src1->ne[1] == 1; + const bool is_gemv = src1->ne[1] == 1; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm; lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info; GGML_ASSERT(kernel); @@ -185,27 +208,30 @@ class tensor_traits : public ggml::cpu::tensor_traits { const int64_t lhs_batch_size0 = ne12; const int64_t rhs_batch_size0 = ne02; - const int64_t batch_size = rhs_batch_size0; + const int64_t batch_size = lhs_batch_size0; + GGML_ASSERT(rhs_batch_size0 > 0); + GGML_ASSERT(lhs_batch_size0 % rhs_batch_size0 == 0); const int64_t r = lhs_batch_size0 / rhs_batch_size0; - const int64_t m = ne11 * r; - const int64_t n = ne01; - const int64_t k = ne00; + const int64_t m_group = ne11; + const int64_t m = m_group; + const int64_t n = ne01; + const int64_t k = ne00; const size_t lhs_stride = src1->nb[1]; const size_t rhs_stride = src0->nb[1]; const size_t dst_stride = dst->nb[1]; - const int64_t mr = static_cast(kernel->get_mr()); - const int64_t nr = static_cast(kernel->get_nr()); - const int64_t kr = static_cast(kernel->get_kr()); - const int64_t sr = static_cast(kernel->get_sr()); + const int64_t mr = (int64_t) kernel->get_mr(); + const int64_t nr = (int64_t) kernel->get_nr(); + const int64_t kr = (int64_t) kernel->get_kr(); + const int64_t sr = (int64_t) kernel->get_sr(); - const size_t lhs_packed_size = variant_call(lhs_info->packed_size, m, k, mr, kr, sr); - const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, n, k); - const size_t kxn_size = k * n * sizeof(float); - const size_t bias_size = n * sizeof(float); + const size_t lhs_packed_size = variant_call(lhs_info->packed_size, (size_t)m, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t rhs_packed_size = variant_call(kernels->rhs_info.packed_size, (size_t)n, (size_t)k); + const size_t kxn_size = (size_t)k * (size_t)n * sizeof(float); + const size_t bias_size = (size_t)n * sizeof(float); const size_t wsize_required = lhs_packed_size + rhs_packed_size + kxn_size + bias_size; GGML_ASSERT(wsize_required <= params->wsize); @@ -216,82 +242,102 @@ class tensor_traits : public ggml::cpu::tensor_traits { uint8_t * bias = rhs_kxn + kxn_size; for (int64_t batch_idx = 0; batch_idx < batch_size; ++batch_idx) { - const uint8_t * lhs_batch = static_cast(src1->data) + batch_idx * m * lhs_stride; - const uint8_t * rhs_batch = static_cast(src0->data) + batch_idx * n * rhs_stride; - uint8_t * dst_batch = static_cast(dst->data) + batch_idx * m * dst_stride; + const int64_t rhs_batch_idx = batch_idx / r; + const uint8_t * rhs_batch_base = static_cast(src0->data) + rhs_batch_idx * src0->nb[2]; + uint8_t * dst_batch_base = static_cast(dst->data) + batch_idx * dst->nb[2]; - // LHS packing + // LHS packing (threaded over m, honoring mr alignment and KV groups) { const int64_t m_roundup_mr = kai_roundup(m, mr); const int64_t num_threads = KAI_MIN(m_roundup_mr / mr, nth); if (ith < num_threads) { - const int64_t num_m_per_thread0 = round_down(m_roundup_mr / num_threads, mr); + const int64_t num_m_per_thread0 = round_down((size_t)(m_roundup_mr / num_threads), (size_t)mr); const int64_t num_m_per_threadN_1 = m - (num_threads - 1) * num_m_per_thread0; - const int64_t m_start = ith * num_m_per_thread0; - const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; + const int64_t m_start = ith * num_m_per_thread0; + const int64_t m_count = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; - const size_t lhs_offset = variant_call(kernels->gemm.get_lhs_offset, m_start, lhs_stride); - const size_t lhs_packed_offset = variant_call(lhs_info->get_packed_offset, m_start, k, mr, kr, sr); + // Base packed offset (aligned) and per-row stride in bytes + const size_t base_packed_off = variant_call( + lhs_info->get_packed_offset, (size_t)m_start, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t next_block_off = variant_call( + lhs_info->get_packed_offset, (size_t)(m_start + mr), (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t row_stride_bytes = (next_block_off - base_packed_off) / (size_t)mr; - const void * src_ptr = static_cast(lhs_batch) + lhs_offset; - void * dst_ptr = static_cast(lhs_packed) + lhs_packed_offset; + int64_t remaining = m_count; + int64_t cur = m_start; - variant_call(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); + while (remaining > 0) { + const int64_t row_in_group = cur; + const int64_t avail = m_group - row_in_group; + const int64_t take = std::min(avail, remaining); + + const uint8_t * lhs_batch_base = static_cast(src1->data) + batch_idx * src1->nb[2]; + const void * src_ptr = lhs_batch_base + (size_t)row_in_group * lhs_stride; + const size_t dst_off = base_packed_off + (size_t)(cur - m_start) * row_stride_bytes; + void * dst_ptr = lhs_packed + dst_off; + + variant_call(lhs_info->pack_func, + (size_t)take, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr, + /*m_idx_start*/ 0, src_ptr, lhs_stride, dst_ptr); + + cur += take; + remaining -= take; + } } } - // RHS packing - if (first_to_arrive.test_and_set(std::memory_order_acquire) == false) { - // First thread to reach this point handles RHS packing - memset(bias, 0, n * sizeof(float)); - transpose_f32kxn_f16nxk(n, k, reinterpret_cast(rhs_kxn), - reinterpret_cast(rhs_batch), rhs_stride); + // RHS packing (single thread), then synchronize + if (ith == 0) { + memset(bias, 0, (size_t)n * sizeof(float)); + transpose_f32kxn_f16nxk((size_t)n, (size_t)k, + reinterpret_cast(rhs_kxn), + reinterpret_cast(rhs_batch_base), + rhs_stride); - variant_call(kernels->rhs_info.pack_func, 1, n, k, nr, kr, sr, n * sizeof(float), - rhs_kxn, bias, nullptr, rhs_packed, 0, nullptr); + variant_call(kernels->rhs_info.pack_func, + /*num_groups*/ 1, (size_t)n, (size_t)k, (size_t)nr, (size_t)kr, (size_t)sr, + /*rhs_stride (bytes)*/ (size_t)(n * sizeof(float)), + rhs_kxn, bias, nullptr, rhs_packed, /*extra_bytes*/ 0, /*params*/ nullptr); } ggml_barrier(params->threadpool); - first_to_arrive.clear(std::memory_order_release); - - // Perform the matmul + // Matmul (threaded over n) { - const int64_t m_to_process = m; - const int64_t m_start = 0; - - const int64_t n_step = static_cast(kernel->get_n_step()); - int64_t num_threads = KAI_MIN(n / n_step, nth); - if (num_threads <= 0) { - num_threads = 1; + const int64_t n_step = (int64_t) kernel->get_n_step(); + int64_t num_threads_n = KAI_MIN(n / n_step, nth); + if (num_threads_n <= 0) { + num_threads_n = 1; } - if (ith < num_threads) { - const int64_t num_n_per_thread0 = round_down(n / num_threads, n_step); - const int64_t num_n_per_threadN_1 = n - (num_threads - 1) * num_n_per_thread0; + if (ith < num_threads_n) { + const int64_t num_n_per_thread0 = round_down((size_t)(n / num_threads_n), (size_t)n_step); + const int64_t num_n_per_threadN_1 = n - (num_threads_n - 1) * num_n_per_thread0; const int64_t n_start = ith * num_n_per_thread0; - const int64_t n_to_process = (ith == num_threads - 1) ? num_n_per_threadN_1 : num_n_per_thread0; + const int64_t n_to_process = (ith == num_threads_n - 1) ? num_n_per_threadN_1 : num_n_per_thread0; - const size_t lhs_packed_offset = variant_call(kernel->get_lhs_offset, m_start, k); - const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, n_start, k); - const size_t dst_offset = kernel->get_dst_offset(m_start, n_start, dst_stride); + // LHS packed base at row 0 (consistent with packing above) + const size_t lhs_packed_offset0 = variant_call( + lhs_info->get_packed_offset, (size_t)0, (size_t)k, (size_t)mr, (size_t)kr, (size_t)sr); + const size_t rhs_packed_offset = variant_call(kernel->get_rhs_packed_offset, (size_t)n_start, (size_t)k); + const size_t dst_offset = kernel->get_dst_offset((size_t)0, (size_t)n_start, dst_stride); - const void * lhs_ptr = lhs_packed + lhs_packed_offset; + const void * lhs_ptr = lhs_packed + lhs_packed_offset0; const void * rhs_ptr = rhs_packed + rhs_packed_offset; - float * dst_ptr = reinterpret_cast(dst_batch + dst_offset); + float * dst_ptr = reinterpret_cast(dst_batch_base + dst_offset); - variant_call(kernel->run_kernel, m_to_process, n_to_process, k, lhs_ptr, rhs_ptr, dst_ptr, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); + variant_call(kernel->run_kernel, + (size_t)m, (size_t)n_to_process, (size_t)k, + lhs_ptr, rhs_ptr, + dst_ptr, dst_stride, sizeof(float), + -FLT_MAX, FLT_MAX); } } if (batch_idx != batch_size - 1) { - // This barrier is necessary when the batch size is larger than 1. While processing a batch, - // the work data buffer (params->wdata) is used as temporary storage which means that only - // a single batch can be processed at any given time. No barrier is needed for the last - // batch since GGML inserts a barrier between the execution of every operator. ggml_barrier(params->threadpool); } } diff --git a/ggml/src/ggml-cpu/spacemit/ime.cpp b/ggml/src/ggml-cpu/spacemit/ime.cpp new file mode 100644 index 0000000000..54d3dece0e --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime.cpp @@ -0,0 +1,1024 @@ +#define GGML_COMMON_IMPL_CPP +#define GGML_COMMON_DECL_CPP + +#include "ime.h" + +#include "ggml-backend-impl.h" +#include "ggml-common.h" +#include "ggml-cpu.h" +#include "ime_kernels.h" +#include "traits.h" + +#include +#include +#include +#include // for GGML_ASSERT +#include +#include + +// clang-format off +#if defined(__riscv) + +#if !defined(__riscv_v) || !defined(__riscv_v_intrinsic) +#error "riscv v extension or v_intrinsic not enabled" +#else +#include +#endif + +#if !defined(__riscv_zfh) +#error "riscv zfh extension not enabled" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#else +#error "RISCV64_SPACEMIT_IME1 not defined" +#endif + +#else + +#error "riscv not enabled in this build" + +#endif + +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif + +#if defined(RISCV64_SPACEMIT_IME1) +#define QGEMM_STRIDEN_THREAD_ALIGN 16 +#else +#define QGEMM_STRIDEN_THREAD_ALIGN 32 +#endif + +// clang-format on + +struct qnbitgemm_spacemit_ime_args { + const float * a_ptr = nullptr; + size_t lda = 0; + const std::byte * packed_quant_b_data = nullptr; + const float * quant_b_scale = nullptr; + const void * quant_b_zp = nullptr; + const float * quant_b_blksum = nullptr; + const float * bias = nullptr; + float * c_ptr = nullptr; + size_t ldc = 0; +}; + +constexpr size_t div_round_up(size_t up, size_t down) { + return (up + down - 1) / down; +} + +constexpr size_t q8_blk_size(size_t blk_len) { + const size_t blk_size = sizeof(float) + blk_len * sizeof(int8_t); + // Currently, the strictest alignment requirement of a block is for a float. + // Ensure contiguous blocks are suitably aligned. + assert(blk_size % alignof(float) == 0); + return blk_size; +} + +namespace ggml::cpu::riscv64_spacemit { + +const int num_ai_cores = std::thread::hardware_concurrency() / 2; + +} // namespace ggml::cpu::riscv64_spacemit + +static void sqnbitgemm_spacemit_ime_i8i4(const size_t blk_len, + const size_t gemm_k, + const qnbitgemm_spacemit_ime_args * gemm_args, + void * const per_gemm_ws, + const size_t m_start, + const size_t m_count, + const size_t n_start, + const size_t n_count) { + constexpr size_t scale_stride = sizeof(uint16_t); + constexpr size_t blk_bitwidth = 4; + + const size_t k_blks = div_round_up(gemm_k, blk_len); + + const size_t lda = k_blks * q8_blk_size(blk_len); + const size_t ldc = gemm_args->ldc; + const size_t ldb = k_blks * (blk_len * blk_bitwidth / 8); + const std::byte * quant_a_ptr = static_cast(per_gemm_ws) + m_start * lda; + + const size_t zero_point_stride = gemm_args->quant_b_zp != nullptr ? sizeof(uint8_t) : 0; + const size_t packed_b_stride = ldb + k_blks * (scale_stride + zero_point_stride); + const std::byte * packed_quant_b_data = gemm_args->packed_quant_b_data + n_start * packed_b_stride; + + float * c_ptr = gemm_args->c_ptr + m_start * ldc + n_start; + + size_t count_n = 0; + const size_t compute_block_count_n = m_count == 1 ? n_count : 16; + for (size_t n = 0; n < n_count; n += count_n) { + count_n = std::min(n_count - n, compute_block_count_n); + + const std::byte * a_row = quant_a_ptr; + const std::byte * b_col = packed_quant_b_data + n * packed_b_stride; + const std::byte * b_col_zp = (zero_point_stride != 0) ? b_col : nullptr; + float * c_blk = c_ptr + n; + + int32_t rows_remaining = m_count; + + while (rows_remaining > 0) { + const auto rows_handled = sqnbitgemm_spacemit_ime::ime1::gemm_kernel_i8i4( + blk_len, a_row, b_col, nullptr, b_col_zp, c_blk, rows_remaining, count_n, gemm_k, k_blks, ldc, nullptr, + scale_stride); + + c_blk += rows_handled * ldc; + a_row += rows_handled * lda; + + rows_remaining -= rows_handled; + } + } +} + +template constexpr int QK_0() { + if constexpr (K == 4) { + return QK4_0; + } + if constexpr (K == 8) { + return QK8_0; + } + return -1; +} + +template struct block { + ggml_half d[N]; // deltas for N qK_0 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_0 blocks +}; + +template struct block_with_zp { + ggml_half d[N]; // deltas for N qK_1 blocks + uint8_t zp[N]; // zero points for N qK_1 blocks + uint8_t qs[(QK_0() * N * K) / 8]; // quants for N qK_1 blocks +}; + +// control size +static_assert(sizeof(block<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8, "wrong block<4,16> size/padding"); +static_assert(sizeof(block_with_zp<4, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 8 + 16 * sizeof(uint8_t), + "wrong block_with_zp<4,16> size/padding"); +static_assert(sizeof(block<8, 16>) == 16 * sizeof(ggml_half) + QK4_0 * 16, "wrong block<8,16> size/padding"); + +using block_q4_0x16 = block<4, 16>; +using block_q4_1x16 = block_with_zp<4, 16>; +using block_q8_0x16 = block<8, 16>; + +static block_q4_0x16 make_block_q4_0x16(block_q4_0 * in, unsigned int blck_size_interleave) { + block_q4_0x16 out; + GGML_ASSERT(QK4_0 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + out.d[i] = in[i].d; + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_0 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_0 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_0 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_0 + i * QK4_0 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_0 / 4] & 0xF0); + } + } + + return out; +} + +static block_q4_1x16 make_block_q4_1x16(block_q4_1 * in, unsigned int blck_size_interleave) { + block_q4_1x16 out; + GGML_ASSERT(QK4_1 / blck_size_interleave == 2); + + for (int i = 0; i < 16; i++) { + float d = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + float m = GGML_FP16_TO_FP32(in[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m); + float mid = -std::nearbyintf(m / d); + mid = std::min(15.0f, std::max(0.0f, mid)); + out.d[i] = GGML_FP32_TO_FP16(d); + out.zp[i] = static_cast(mid); + } + + for (int i = 0; i < 16; i++) { + // [0, 15], in.d & 0x0F + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b0 b8] ......... [b7 b15] + out.qs[i * QK4_1 / 4 + j] = (in[i].qs[j] & 0x0F) | ((in[i].qs[j + QK4_1 / 4] & 0x0F) << 4); + } + } + + for (int i = 0; i < 16; i++) { + // [16, 31], in.d & 0xF0 + for (int j = 0; j < QK4_1 / 4; j++) { + //src [b0 b16] ......... [b8 b24] ......... [b15 b31] + //dst [b16 b24] ......... [b23 b31] + out.qs[4 * QK4_1 + i * QK4_1 / 4 + j] = ((in[i].qs[j] & 0xF0) >> 4) | (in[i].qs[j + QK4_1 / 4] & 0xF0); + } + } + + return out; +} + +static int repack_q4_0_to_q4_0_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_0); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_0x16 * dst = (block_q4_0x16 *) t->data; + const block_q4_0 * src = (const block_q4_0 *) data; + block_q4_0 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_0; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_0)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_0 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_0x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static int repack_q4_1_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_1); + GGML_ASSERT(interleave_block == 16); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_1 * src = (const block_q4_1 *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK4_1; + + GGML_ASSERT(data_size == nrow * nblocks * sizeof(block_q4_1)); + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK4_1 != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int i = 0; i < nrows_interleaved; i++) { + dst_tmp[i] = src[x + i * nblocks]; + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +static inline void get_scale_min_k4(int j, + const uint8_t * GGML_RESTRICT q, + uint8_t * GGML_RESTRICT d, + uint8_t * GGML_RESTRICT m) { + if (j < 4) { + *d = q[j] & 63; + *m = q[j + 4] & 63; + } else { + *d = (q[j + 4] & 0xF) | ((q[j - 4] >> 6) << 4); + *m = (q[j + 4] >> 4) | ((q[j - 0] >> 6) << 4); + } +} + +static int repack_q4_k_to_q4_1_16_bl(struct ggml_tensor * t, + int interleave_block, + const void * GGML_RESTRICT data, + size_t data_size) { + GGML_ASSERT(t->type == GGML_TYPE_Q4_K); + GGML_ASSERT(interleave_block == 16); + GGML_ASSERT(QK_K / QK4_1 == 8); + + constexpr int nrows_interleaved = 16; + + block_q4_1x16 * dst = (block_q4_1x16 *) t->data; + const block_q4_K * src = (const block_q4_K *) data; + block_q4_1 dst_tmp[16]; + int nrow = ggml_nrows(t); + int nblocks = t->ne[0] / QK_K; + + if (t->ne[1] % nrows_interleaved != 0 || t->ne[0] % QK_K != 0) { + return -1; + } + + for (int b = 0; b < nrow; b += nrows_interleaved) { + for (int64_t x = 0; x < nblocks; x++) { + for (int j = 0; j < 8; j++) { + for (int i = 0; i < nrows_interleaved; i++) { + uint8_t sc, m; + const float d = GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d); + const float min = + GGML_FP16_TO_FP32(src[x + i * nblocks].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.dmin); + get_scale_min_k4(j, src[x + i * nblocks].scales, &sc, &m); + const float d1 = d * sc; + const float m1 = min * m; + + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d = GGML_FP32_TO_FP16(d1); + dst_tmp[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.m = GGML_FP32_TO_FP16(-m1); + // src -> [b0, b32] [b1, b33] ... [b31, b63] + // dst -> [b0, b16] [b1, b17] ... [b15, b31] [b32, b48] [b33, b49] ... [b47, b63] + const uint8_t * q = src[x + i * nblocks].qs + (j / 2) * QK4_1; + if (j % 2 == 0) { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = (q[ii] & 0x0F) | ((q[ii + 16] & 0x0F) << 4); + } + } else { + for (int ii = 0; ii < 16; ii++) { + dst_tmp[i].qs[ii] = ((q[ii] & 0xF0) >> 4) | (q[ii + 16] & 0xF0); + } + } + } + *dst++ = make_block_q4_1x16(dst_tmp, interleave_block); + } + } + src += nrows_interleaved * nblocks; + } + return 0; + + GGML_UNUSED(data_size); +} + +namespace ggml::cpu::riscv64_spacemit { + +template +int repack(struct ggml_tensor *, const void *, size_t); + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_0_to_q4_0_16_bl(t, 16, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_1_to_q4_1_16_bl(t, 16, data, data_size); +} + +template <> int repack(struct ggml_tensor * t, const void * data, size_t data_size) { + return repack_q4_k_to_q4_1_16_bl(t, 16, data, data_size); +} + +class tensor_traits_base : public ggml::cpu::tensor_traits { + public: + virtual int repack(struct ggml_tensor * t, const void * data, size_t data_size) = 0; +}; + +template class tensor_traits : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + size = ggml_row_size(GGML_TYPE_Q8_0, ggml_nelements(op->src[1])) * 4; + size = ((size + QK4_0 - 1) / QK4_0) * (QK4_0 * sizeof(float) + sizeof(float)); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->type == GGML_TYPE_Q4_0 || // + op->src[0]->type == GGML_TYPE_Q4_1 || // + op->src[0]->type == GGML_TYPE_Q4_K) { + forward_mul_mat_q4(params, op); + return true; + } + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_mul_mat_q4(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + const ggml_tensor * src1 = op->src[1]; + ggml_tensor * dst = op; + + GGML_TENSOR_BINARY_OP_LOCALS + + int ith = params->ith; + int nth = params->nth; + + [[maybe_unused]] const enum ggml_type type = src0->type; + + void * w_data = (void *) src0->data; + const float * feature = (const float *) src1->data; + float * output = (float *) dst->data; + + const size_t batch_feature = ne12 * ne13; + [[maybe_unused]] const size_t batch_weight = ne02 * ne03; + const size_t gemm_m = ne11; + const size_t gemm_k = ne10; + const size_t gemm_n = ne01; + + GGML_ASSERT(batch_weight == 1); + + const size_t block_count_k = div_round_up(gemm_k, QK4_0); + const size_t per_gemm_workspace_size = gemm_m * block_count_k * q8_blk_size(QK4_0); + const size_t per_gemm_workspace_stride = + div_round_up(per_gemm_workspace_size, alignof(uint64_t)) * alignof(uint64_t); + const size_t gemm_workspace_size = batch_feature * per_gemm_workspace_stride; + const size_t desired_wsize = gemm_workspace_size + alignof(uint64_t) - 1; + + if (ith == 0 && params->wsize < desired_wsize) { + throw std::runtime_error("wsize less than desired_wsize"); + } + + std::vector qnbitgemm_args(batch_feature); + + for (size_t i = 0; i < batch_feature; i++) { + qnbitgemm_args[i].a_ptr = feature + gemm_m * gemm_k * i; + qnbitgemm_args[i].lda = gemm_k; + qnbitgemm_args[i].packed_quant_b_data = (const std::byte *) w_data; + qnbitgemm_args[i].quant_b_scale = nullptr; + + if constexpr (std::is_same_v) { + qnbitgemm_args[i].quant_b_zp = nullptr; + } else { + qnbitgemm_args[i].quant_b_zp = w_data; + } + + qnbitgemm_args[i].bias = nullptr; + qnbitgemm_args[i].c_ptr = output + gemm_m * gemm_n * i; + qnbitgemm_args[i].ldc = gemm_n; + } + + const uintptr_t ws_ptr = reinterpret_cast(params->wdata); + void * ws = reinterpret_cast((ws_ptr + alignof(uint64_t) - 1) & (~(alignof(uint64_t) - 1))); + const size_t quant_a_stride = block_count_k * q8_blk_size(QK4_0); + + { + constexpr size_t block_size_m = 4; + size_t per_gemm_block_count_m = div_round_up(gemm_m, block_size_m); + int32_t task_count = batch_feature * per_gemm_block_count_m; + int32_t task_per_thread = (task_count + nth - 1) / nth; + int32_t start = ith * task_per_thread; + int32_t end = std::min((ith + 1) * task_per_thread, task_count); + for (int32_t compute_idx = start; compute_idx < end; compute_idx++) { + int32_t gemm_idx = compute_idx / block_size_m; + int32_t m_idx = compute_idx % block_size_m * block_size_m; + const qnbitgemm_spacemit_ime_args & data = qnbitgemm_args[gemm_idx]; + int32_t rows_tobe_handled = (gemm_m - m_idx) > block_size_m ? block_size_m : (gemm_m - m_idx); + + if (rows_tobe_handled == block_size_m) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = + static_cast(ws) + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_4row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + } else { + while (rows_tobe_handled) { + const float * a_row_ptr = data.a_ptr + m_idx * data.lda; + std::byte * quant_a_row_ptr = static_cast(ws) + + gemm_idx * per_gemm_workspace_stride + m_idx * quant_a_stride; + sqnbitgemm_spacemit_ime::ime1::quantize_a_row_i8(QK4_0, a_row_ptr, gemm_k, quant_a_row_ptr); + rows_tobe_handled -= 1; + m_idx += 1; + } + } + } + } + + ggml_barrier(params->threadpool); + + if (ith >= ggml::cpu::riscv64_spacemit::num_ai_cores) { + return; + } + nth = std::min(nth, int{ ggml::cpu::riscv64_spacemit::num_ai_cores }); + + size_t threads_per_gemm = nth / batch_feature; + constexpr size_t gemm_m_stride = 128; + size_t nc = gemm_n; + const size_t gemm_m_blocked = div_round_up(gemm_m, gemm_m_stride); + const size_t max_nc = div_round_up(gemm_n * gemm_m_blocked, threads_per_gemm); + if (max_nc < nc) { + nc = std::min(nc, div_round_up(max_nc, QGEMM_STRIDEN_THREAD_ALIGN) * QGEMM_STRIDEN_THREAD_ALIGN); + } + const size_t gemm_n_stride = nc; + const size_t thread_count_m = div_round_up(gemm_m, gemm_m_stride); + const size_t thread_count_n = div_round_up(gemm_n, gemm_n_stride); + threads_per_gemm = thread_count_m * thread_count_n; + + { + int task_count = batch_feature * threads_per_gemm; + int task_per_thread = (task_count + nth - 1) / nth; + int start = ith * task_per_thread; + int end = std::min((ith + 1) * task_per_thread, task_count); + for (int compute_idx = start; compute_idx < end; compute_idx++) { + const auto gemm_i = compute_idx / threads_per_gemm; + const auto blk_i = compute_idx % threads_per_gemm; + const auto * data = &qnbitgemm_args[gemm_i]; + + const auto tid_n = blk_i / thread_count_m; + const auto tid_m = blk_i % thread_count_m; + + const size_t m_start = tid_m * gemm_m_stride; + const size_t m_count = std::min(gemm_m - m_start, (size_t) gemm_m_stride); + + const size_t n_start = tid_n * gemm_n_stride; + const size_t n_count = std::min(gemm_n - n_start, (size_t) gemm_n_stride); + + void * per_gemm_ws = reinterpret_cast(ws) + gemm_i * per_gemm_workspace_stride; + + sqnbitgemm_spacemit_ime_i8i4(QK4_0, gemm_k, data, per_gemm_ws, m_start, m_count, n_start, n_count); + } + } + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + GGML_LOG_DEBUG("%s: repack tensor %s with %s_%dx%d\n", __func__, t->name, ggml_type_name(t->type), + (int) NB_COLS, (int) INTER_SIZE); + return ggml::cpu::riscv64_spacemit::repack(t, data, data_size); + } +}; + +class tensor_traits_common : public tensor_traits_base { + bool work_size(int /* n_threads */, const struct ggml_tensor * op, size_t & size) override { + switch (op->op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + size = 0; + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + bool compute_forward(struct ggml_compute_params * params, struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_NORM: + forward_norm_f32(params, op); + return true; + case GGML_OP_RMS_NORM: + forward_rms_norm_f32(params, op); + return true; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + void forward_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + size_t gvl = __riscv_vsetvlmax_e32m4(); + vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum = __riscv_vfadd_vv_f32m4(sum, src_data, gvl); + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + vfloat32m1_t mean_v = + __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum, 0), __riscv_vget_v_f32m4_f32m1(sum, 1), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 2), gvl); + mean_v = __riscv_vfadd_vv_f32m1(mean_v, __riscv_vget_v_f32m4_f32m1(sum, 3), gvl); + mean_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_v, zero_v, gvl); + mean = __riscv_vfmv_f_s_f32m1_f32(mean_v); + mean /= hidden_size; + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + mean_square = sqrt(mean_square - mean * mean + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfsub_vf_f32m4(src_data, mean, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + void forward_rms_norm_f32(ggml_compute_params * params, ggml_tensor * op) { + const ggml_tensor * src0 = op->src[0]; + ggml_tensor * dst = op; + GGML_ASSERT(ggml_are_same_shape(src0, dst)); + GGML_ASSERT(src0->nb[0] == sizeof(float)); + + const int ith = params->ith; + const int nth = params->nth; + + GGML_TENSOR_UNARY_OP_LOCALS + + float epsilon; + memcpy(&epsilon, dst->op_params, sizeof(float)); + + GGML_ASSERT(epsilon > 0.0f); + + auto * input = (float *) src0->data; + auto * output = (float *) dst->data; + + const auto hidden_size = ne00; + const auto task_count = ne01 * ne02 * ne03; + const auto task_per_thread = (task_count + nth - 1) / nth; + + const auto task_begin = ith * task_per_thread; + const auto task_end = std::min((ith + 1) * task_per_thread, task_count); + + for (auto task_idx = task_begin; task_idx < task_end; task_idx++) { + auto offset = task_idx * hidden_size; + auto * p_input = const_cast(input + offset); + auto * p_output = output + offset; + auto * p_temp_output = p_output; + auto * p_gamma_data = (const float *) nullptr; + auto * p_beta_data = (const float *) nullptr; + + size_t gvl = __riscv_vsetvlmax_e32m4(); + // vfloat32m4_t sum = __riscv_vfmv_v_f_f32m4(0.f, gvl); + vfloat32m4_t sum_sq = __riscv_vfmv_v_f_f32m4(0.f, gvl); + int64_t length = hidden_size; + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + // load data + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_input, gvl); + + sum_sq = __riscv_vfmacc_vv_f32m4(sum_sq, src_data, src_data, gvl); + + __riscv_vse32_v_f32m4(p_temp_output, src_data, gvl); + + p_input += gvl; + p_temp_output += gvl; + length -= gvl; + } + + gvl = __riscv_vsetvlmax_e32m1(); + + // float mean = 0.f; + vfloat32m1_t zero_v = __riscv_vfmv_v_f_f32m1(0.f, gvl); + + vfloat32m1_t mean_square_v = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m4_f32m1(sum_sq, 0), + __riscv_vget_v_f32m4_f32m1(sum_sq, 1), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 2), gvl); + mean_square_v = __riscv_vfadd_vv_f32m1(mean_square_v, __riscv_vget_v_f32m4_f32m1(sum_sq, 3), gvl); + mean_square_v = __riscv_vfredusum_vs_f32m1_f32m1(mean_square_v, zero_v, gvl); + + float mean_square = __riscv_vfmv_f_s_f32m1_f32(mean_square_v); + mean_square /= hidden_size; + + mean_square = sqrt(mean_square + epsilon); + + mean_square = 1.0f / mean_square; + length = hidden_size; + p_temp_output = p_output; + + if (p_gamma_data == nullptr && p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + length -= gvl; + } + } else if (p_beta_data == nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } else if (p_gamma_data != nullptr) { + while (length > 0) { + gvl = __riscv_vsetvl_e32m4(length); + vfloat32m4_t src_data = __riscv_vle32_v_f32m4(p_temp_output, gvl); + vfloat32m4_t gamma_data_v = __riscv_vle32_v_f32m4(p_gamma_data, gvl); + src_data = __riscv_vfmul_vf_f32m4(src_data, mean_square, gvl); + src_data = __riscv_vfmul_vv_f32m4(src_data, gamma_data_v, gvl); + vfloat32m4_t beta_data_v = __riscv_vle32_v_f32m4(p_beta_data, gvl); + src_data = __riscv_vfadd_vv_f32m4(src_data, beta_data_v, gvl); + p_beta_data += gvl; + __riscv_vse32_v_f32m4(p_output, src_data, gvl); + p_temp_output += gvl; + p_output += gvl; + p_gamma_data += gvl; + length -= gvl; + } + } + } + } + + int repack(struct ggml_tensor * t, const void * data, size_t data_size) override { + memcpy(t->data, data, data_size); + return 0; + } +}; + +static const tensor_traits q4_0_16x8_q8_0; +static const tensor_traits q4_1_16x8_q8_0; +static const tensor_traits q4_k_16x8_q8_0; +static const tensor_traits_common rvv_impl; + +} // namespace ggml::cpu::riscv64_spacemit + +static const ggml::cpu::tensor_traits * ggml_riscv64_spacemit_get_optimal_repack_type(const struct ggml_tensor * cur) { + if (cur->type == GGML_TYPE_Q4_0) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_0_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_1) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_1_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_Q4_K) { + if (cur->ne[1] % 16 == 0) { + return &ggml::cpu::riscv64_spacemit::q4_k_16x8_q8_0; + } + } else if (cur->type == GGML_TYPE_F32) { + return &ggml::cpu::riscv64_spacemit::rvv_impl; + } + + return nullptr; +} + +static enum ggml_status ggml_backend_riscv64_spacemit_buffer_init_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor) { + tensor->extra = + (void *) const_cast(ggml_riscv64_spacemit_get_optimal_repack_type(tensor)); + + GGML_UNUSED(buffer); + + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_riscv64_spacemit_buffer_set_tensor(ggml_backend_buffer_t buffer, + struct ggml_tensor * tensor, + const void * data, + size_t offset, + size_t size) { + GGML_ASSERT(offset == 0); + GGML_ASSERT(size == ggml_nbytes(tensor)); + + auto tensor_traits = (ggml::cpu::riscv64_spacemit::tensor_traits_base *) tensor->extra; + if (tensor_traits) { + auto OK = tensor_traits->repack(tensor, data, size); + GGML_ASSERT(OK == 0); + } + + GGML_UNUSED(buffer); +} + +static const char * ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name(ggml_backend_buffer_type_t buft) { + return "CPU_RISCV64_SPACEMIT"; + + GGML_UNUSED(buft); +} + +static ggml_backend_buffer_t ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, + size_t size) { + ggml_backend_buffer_t buffer = ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + + if (buffer == nullptr) { + return nullptr; + } + + buffer->buft = buft; + buffer->iface.init_tensor = ggml_backend_riscv64_spacemit_buffer_init_tensor; + buffer->iface.set_tensor = ggml_backend_riscv64_spacemit_buffer_set_tensor; + buffer->iface.get_tensor = nullptr; + buffer->iface.cpy_tensor = nullptr; + return buffer; +} + +static size_t ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return 64; + + GGML_UNUSED(buft); +} + +static size_t ggml_backend_cpu_riscv64_spacemit_nbytes(ggml_backend_buffer_type_t buft, + const struct ggml_tensor * tensor) { + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + if (tensor->ne[i] <= 0) { + return 0; + } + } + + size_t nbytes; + const size_t blck_size = ggml_blck_size(tensor->type); + if (blck_size == 1) { + nbytes = ggml_type_size(tensor->type); + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } else { + nbytes = tensor->ne[0] * tensor->nb[0] / blck_size; + if (tensor->type == GGML_TYPE_Q4_K) { + GGML_ASSERT(nbytes % sizeof(block_q4_K) == 0); + nbytes = (nbytes / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * (tensor->nb[i] / sizeof(block_q4_K)) * sizeof(block_q4_1) * 8; + } + } else { + for (int i = 1; i < GGML_MAX_DIMS; ++i) { + nbytes += (tensor->ne[i] - 1) * tensor->nb[i]; + } + } + } + + GGML_UNUSED(buft); + return nbytes; +} + +namespace ggml::cpu::riscv64_spacemit { + +class extra_buffer_type : ggml::cpu::extra_buffer_type { + bool supports_op(ggml_backend_dev_t, const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && (ggml_n_dims(op->src[0]) == 2) && + op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type() && + ggml_riscv64_spacemit_get_optimal_repack_type(op->src[0])) { + if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { + return false; + } + if (op->src[1]->type == GGML_TYPE_F32) { + return true; + } + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + if (op->src[0]->type == GGML_TYPE_F32) { + return true; + } + break; + default: + // GGML_ABORT("fatal error"); + break; + } + return false; + } + + ggml::cpu::tensor_traits * get_tensor_traits(const struct ggml_tensor * op) override { + switch (op->op) { + case GGML_OP_MUL_MAT: + if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_riscv64_spacemit_buffer_type()) { + return (ggml::cpu::tensor_traits *) op->src[0]->extra; + } + break; + case GGML_OP_NORM: + case GGML_OP_RMS_NORM: + return (ggml::cpu::tensor_traits *) (&ggml::cpu::riscv64_spacemit::rvv_impl); + default: + // GGML_ABORT("fatal error"); + break; + } + + return nullptr; + } +}; + +} // namespace ggml::cpu::riscv64_spacemit + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void) { + static struct ggml_backend_buffer_type ggml_backend_cpu_buffer_type_riscv64_spacemit = { + /* .iface = */ + { + /* .get_name = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_name, + /* .alloc_buffer = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_cpu_riscv64_spacemit_buffer_type_get_alignment, + /* .get_max_size = */ nullptr, + /* .get_alloc_size = */ ggml_backend_cpu_riscv64_spacemit_nbytes, + /* .is_host = */ nullptr, + }, + /* .device = */ + ggml_backend_reg_dev_get(ggml_backend_cpu_reg(), 0), + /* .context = */ + new ggml::cpu::riscv64_spacemit::extra_buffer_type(), + }; + + return &ggml_backend_cpu_buffer_type_riscv64_spacemit; +} diff --git a/ggml/src/ggml-cpu/spacemit/ime.h b/ggml/src/ggml-cpu/spacemit/ime.h new file mode 100644 index 0000000000..800d91acda --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime.h @@ -0,0 +1,13 @@ +#pragma once + +#include "ggml-alloc.h" + +#ifdef __cplusplus +extern "C" { +#endif + +ggml_backend_buffer_type_t ggml_backend_cpu_riscv64_spacemit_buffer_type(void); + +#ifdef __cplusplus +} +#endif diff --git a/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp new file mode 100644 index 0000000000..cbbb6cd916 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime1_kernels.cpp @@ -0,0 +1,3196 @@ +#include "ggml.h" +#include "ime_kernels.h" + +#include +#include + +// clang-format off +#if defined(__GNUC__) +#pragma GCC diagnostic ignored "-Woverlength-strings" +#pragma GCC diagnostic ignored "-Wcast-qual" +#pragma GCC diagnostic ignored "-Wunused-parameter" +#endif +// clang-format on +namespace sqnbitgemm_spacemit_ime { + +#define QUANTIZEM4ROW_KERNEL \ + "vmv.s.x v16, zero \n\t" \ + "vfabs.v v8, v0 \n\t" \ + "vfredmax.vs v16, v8, v16 \n\t" \ + "vfmv.f.s f10, v16 \n\t" \ + "fmul.s f10, f10, %[RMAXREC] \n\t" \ + "fsw f10, (a1) \n\t" \ + "fdiv.s f11, %[FONE], f10 \n\t" \ + "vfmul.vf v16, v0, f11 \n\t" \ + "vfcvt.x.f.v v16, v16 \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vnclip.wx v16, v16, zero \n\t" \ + "vnclip.wx v17, v17, zero \n\t" \ + "vnclip.wx v18, v18, zero \n\t" \ + "vnclip.wx v19, v19, zero \n\t" \ + "vnclip.wx v20, v20, zero \n\t" \ + "vnclip.wx v21, v21, zero \n\t" \ + "vnclip.wx v22, v22, zero \n\t" \ + "vnclip.wx v23, v23, zero \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vnclip.wx v24, v16, zero \n\t" \ + "vnclip.wx v25, v17, zero \n\t" \ + "vnclip.wx v26, v18, zero \n\t" \ + "vnclip.wx v27, v19, zero \n\t" \ + "vnclip.wx v28, v20, zero \n\t" \ + "vnclip.wx v29, v21, zero \n\t" \ + "vnclip.wx v30, v22, zero \n\t" \ + "vnclip.wx v31, v23, zero \n\t" + +#define QUANTIZEM4ROW_STORE \ + "addi t1, %[BlkLen], 0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v24, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v25, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v26, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v27, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v28, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v29, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v30, (s1) \n\t" \ + "addi s1, s1, 32 \n\t" \ + "sub t1, t1, t0 \n\t" \ + "vsetvli t0, t1, e8, mf4 \n\t" \ + "vse8.v v31, (s1) \n\t" + +namespace ime1 { +void quantize_a_4row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + + if (BlkLen == 16 || BlkLen == 32 || BlkLen == 64) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, %[BlkLen], e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "sub t2, t2, t0 \n\t" + "slli t1, t0, 2 \n\t" + "add %[SRC], %[SRC], t1 \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL QUANTIZEM4ROW_STORE + + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "add s1, a1, %[OFFSET] \n\t" + + QUANTIZEM4ROW_KERNEL + + "addi t3, %[BlkLen], 0 \n\t" + "addi s2, s1, 0 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "SET_ZERO%=: \n\t" + "vse8.v v8, (s2) \n\t" + "addi s2, s2, 32 \n\t" + "addi t3, t3, -8 \n\t" + "bnez t3, SET_ZERO%= \n\t" + + QUANTIZEM4ROW_STORE + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t3", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 128) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "addi t2, t2, -128 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v16, v24, v16 \n\t" + "vfredmax.vs v24, v16, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e64, m4 \n\t" + "vsse64.v v16, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t2, e32, m8 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "sub t2, t2, t2 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "jal x0, QUANTIZE%= \n\t" + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } else if (BlkLen == 256) { + for (size_t row_index = 0; row_index < 4; ++row_index) { + const float * SRC = A + row_index * CountK; + std::byte * DST = QuantA + row_index * sizeof(float); + const size_t offset = (4 - row_index) * 4 + row_index * 8; + const size_t stride = 4 * (sizeof(float) + BlkLen); + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "li t6, 32 \n\t" + "addi t2, %[CountK], 0 \n\t" + "addi a1, %[DST], 0 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "blt t2, %[BlkLen], TAIL%= \n\t" + + "LOOP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "addi t2, t2, -256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v24, v16 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + + "QUANTIZE%=: \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + "add a1, a1, %[STRIDE] \n\t" + "bge t2, %[BlkLen], LOOP%= \n\t" + + "TAIL%=: \n\t" + "blez t2, QUIT%= \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t1, t2, 0 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "sub t1, t1, t0 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, t1, e32, m8 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], -768 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vfmax.vv v8, v8, v24 \n\t" + "vfredmax.vs v24, v8, v24 \n\t" + "vfmv.f.s f10, v24 \n\t" + "add s1, a1, %[OFFSET] \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (a1) \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e64, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsse64.v v0, (s1), t6 \n\t" + + "TAIL_LOOP%=: \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t2, e32, m1 \n\t" + "sub t2, t2, t0 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 32 \n\t" + "vfmul.vf v1, v0, f11 \n\t" + "vfcvt.x.f.v v2, v1 \n\t" + "vsetvli t0, zero, e16, mf2 \n\t" + "vnclip.wx v3, v2, zero \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vnclip.wx v3, v3, zero \n\t" + "vse8.v v3, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bnez t2, TAIL_LOOP%= \n\t" + + "QUIT%=: \n\t" + : [SRC] "+r"(SRC) + : [DST] "r"(DST), [BlkLen] "r"(BlkLen), [OFFSET] "r"(offset), [STRIDE] "r"(stride), + [CountK] "r"(CountK), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t0", "t1", "t2", "t6", "a1", "s1", "s2", "f10", "f11"); + } + } +} + +void quantize_a_row_i8(size_t BlkLen, const float * A, size_t CountK, std::byte * QuantA) { + const float * SRC = A; + std::byte * DST = QuantA; + constexpr float range_max_reciprocal = 1.0f / ((1 << 7) - 1); + const float fone = 1.0f; + std::byte * QuantA_offset = QuantA + CountK + 4 * ((CountK + BlkLen - 1) / BlkLen); + size_t offset = (CountK + BlkLen - 1) / BlkLen * BlkLen - CountK; + + if (CountK <= BlkLen) { + float max_abs_A = 0.0f; + for (size_t k = 0; k < CountK; k++) { + max_abs_A = std::max(max_abs_A, fabsf(A[k])); + } + float scale_A = max_abs_A * range_max_reciprocal; + + ((float *) QuantA)[0] = scale_A; + + auto * QuantAData_offset = (int8_t *) (QuantA + sizeof(float)); + + for (size_t k = 0; k < CountK; k++) { + QuantAData_offset[k] = + (int8_t) std::clamp(roundf(A[k] / scale_A), (float) std::numeric_limits::lowest(), + (float) std::numeric_limits::max()); + } + for (size_t k = CountK; k < BlkLen; k++) { + QuantAData_offset[k] = 0; + } + + return; + } + + if (BlkLen != 32 || BlkLen != 64 || BlkLen != 128) { + __asm__ volatile( + "vsetvli t0, zero, e8, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "LOOP%=: \n\t" + "vsetvli t0, %[CNT], e8, m8 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "sub %[CNT], %[CNT], t0 \n\t" + "bnez %[CNT], LOOP%= \n\t" + : [DST] "+r"(QuantA_offset), [CNT] "+r"(offset) + : + : "cc", "t0"); + } + if (BlkLen == 16) { + float buffer[64] = { 0.0f }; + __asm__ volatile( + "addi t3, zero, 16*8 \n\t" + "addi t2, zero, 16 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m2 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v2, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v4, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v6, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v10, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v12, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "vle32.v v14, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v18, v2 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v22, v6 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v26, v10 \n\t" + "vfabs.v v28, v12 \n\t" + "vfabs.v v30, v14 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v18, v18, v19 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v22, v22, v23 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v26, v26, v27 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vse32.v v16, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v18, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v20, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v22, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v24, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v26, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v28, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vse32.v v30, (a1) \n\t" + "addi a1, %[BUFFER], 0 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f11, f3, f7 \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f11, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f12, f3, f7 \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fsw f12, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f13, f3, f7 \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f13, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f14, f3, f7 \n\t" + "fmul.s f14, f14, %[RMAXREC] \n\t" + "fsw f14, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f14, %[FONE], f14 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f15, f3, f7 \n\t" + "fmul.s f15, f15, %[RMAXREC] \n\t" + "fsw f15, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f15, %[FONE], f15 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f16, f3, f7 \n\t" + "fmul.s f16, f16, %[RMAXREC] \n\t" + "fsw f16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "fdiv.s f16, %[FONE], f16 \n\t" + "flw f0, (a1) \n\t" + "flw f1, 4(a1) \n\t" + "flw f2, 8(a1) \n\t" + "flw f3, 12(a1) \n\t" + "flw f4, 16(a1) \n\t" + "flw f5, 20(a1) \n\t" + "flw f6, 24(a1) \n\t" + "flw f7, 28(a1) \n\t" + "addi a1, a1, 32 \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f17, f3, f7 \n\t" + "fmul.s f17, f17, %[RMAXREC] \n\t" + "fsw f17, (%[DST]) \n\t" + "addi %[DST], %[DST], -136 \n\t" + "fdiv.s f17, %[FONE], f17 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v18, v2, f11 \n\t" + "vfmul.vf v20, v4, f12 \n\t" + "vfmul.vf v22, v6, f13 \n\t" + "vfmul.vf v24, v8, f14 \n\t" + "vfmul.vf v26, v10, f15 \n\t" + "vfmul.vf v28, v12, f16 \n\t" + "vfmul.vf v30, v14, f17 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v18, v18 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v22, v22 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v26, v26 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vfcvt.x.f.v v30, v30 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v18, v18, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v22, v22, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v26, v26, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vnclip.wx v30, v30, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v18, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v20, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v22, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v24, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v26, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v28, (%[DST]) \n\t" + "addi %[DST], %[DST], 20 \n\t" + "vse8.v v30, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 64 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vse32.v v16, (%[BUFFER]) \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, t1, e8, mf2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 16 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m2 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BUFFER] "r"(buffer) + : "cc", "t3", "t2", "t1", "t0", "a1", "f0", "f1", "f2", "f3", "f4", "f5", "f6", "f7", "f10", "f11", "f12", + "f13", "f14", "f15", "f16", "f17"); + } else if (BlkLen == 32) { + __asm__ volatile( + "addi t3, zero, 32*4 \n\t" + "addi t2, zero, 32 \n\t" + + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 128 \n\t" + "addi a3, %[SRC], 256 \n\t" + "addi a4, %[SRC], 384 \n\t" + + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 36 \n\t" + "addi s3, %[DST], 72 \n\t" + "addi s4, %[DST], 108 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v4, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vle32.v v8, (a3) \n\t" + "addi a3, a3, 512 \n\t" + "vle32.v v12, (a4) \n\t" + "addi a4, a4, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v20, v4 \n\t" + "vfabs.v v24, v8 \n\t" + "vfabs.v v28, v12 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v20, v20, v22 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vfmax.vv v28, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v20, v20, v21 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfmax.vv v28, v28, v29 \n\t" + + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v21, v20, v21 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfredmax.vs v29, v28, v29 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v21 \n\t" + "vfmv.f.s f12, v25 \n\t" + "vfmv.f.s f13, v29 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fmul.s f12, f12, %[RMAXREC] \n\t" + "fmul.s f13, f13, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fsw f12, (s3) \n\t" + "addi s3, s3, 4 \n\t" + "fsw f13, (s4) \n\t" + "addi s4, s4, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "fdiv.s f12, %[FONE], f12 \n\t" + "fdiv.s f13, %[FONE], f13 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v20, v4, f11 \n\t" + "vfmul.vf v24, v8, f12 \n\t" + "vfmul.vf v28, v12, f13 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v20, v20 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vfcvt.x.f.v v28, v28 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vsetvli t0, t1, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v20, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vnclip.wx v28, v28, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 140 \n\t" + "vse8.v v20, (s2) \n\t" + "addi s2, s2, 140 \n\t" + "vse8.v v24, (s3) \n\t" + "addi s3, s3, 140 \n\t" + "vse8.v v28, (s4) \n\t" + "addi s4, s4, 140 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m4 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 128 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m4 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC), [DST] "r"(DST) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "a3", "a4", "s1", "s2", "s3", "s4", "f10", "f11", "f12", "f13"); + } else if (BlkLen == 64) { + __asm__ volatile( + "addi t3, zero, 64*2 \n\t" + "addi t2, zero, 64 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "addi s1, %[DST], 0 \n\t" + "addi s2, %[DST], 68 \n\t" + "blt %[K], t3, LOOP_K%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "addi %[K], %[K], -128 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vfmax.vv v24, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vfmax.vv v24, v24, v26 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfmax.vv v24, v24, v25 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfredmax.vs v25, v24, v25 \n\t" + "vfmv.f.s f10, v17 \n\t" + "vfmv.f.s f11, v25 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fmul.s f11, f11, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fsw f11, (s2) \n\t" + "addi s2, s2, 4 \n\t" + "fdiv.s f10, %[FONE], f10 \n\t" + "fdiv.s f11, %[FONE], f11 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f10 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v24, v24, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 132 \n\t" + "vse8.v v24, (s2) \n\t" + "addi s2, s2, 132 \n\t" + "bge %[K], t3, LOOP_MAIN%= \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 256 \n\t" + "sub %[K], %[K], t1 \n\t" + "vfabs.v v16, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v16, v16, v20 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v16, v16, v18 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v16, v16, v17 \n\t" + "vfredmax.vs v17, v16, v17 \n\t" + "vfmv.f.s f10, v17 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (s1) \n\t" + "addi s1, s1, 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vsetvli t0, zero, e8, m2 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t0, t3, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "jal x0, LOOP_K%= \n\t" + "END%=: \n\t" + : [K] "+r"(CountK) + : [SRC] "r"(SRC), [DST] "r"(DST), [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal) + : "cc", "t3", "t2", "t1", "t0", "a1", "a2", "s1", "s2", "f10", "f11"); + } else if (BlkLen == 128) { + __asm__ volatile( + "addi t2, zero, 128 \n\t" + "addi a1, %[SRC], 0 \n\t" + "addi a2, %[SRC], 256 \n\t" + "blt %[K], t2, TAIL%= \n\t" + "LOOP_K%=: \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "addi a1, a1, 512 \n\t" + "vle32.v v8, (a2) \n\t" + "addi a2, a2, 512 \n\t" + "sub %[K], %[K], t2 \n\t" + "QUANT%=: \n\t" + "vfabs.v v16, v0 \n\t" + "vfabs.v v24, v8 \n\t" + "vfmax.vv v24, v16, v24 \n\t" + "vsetvli t1, zero, e32, m4 \n\t" + "vfmax.vv v28, v24, v28 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v30, v28, v30 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v30, v30, v31 \n\t" + "vfredmax.vs v31, v30, v31 \n\t" + "vfmv.f.s f10, v31 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v16, v0, f11 \n\t" + "vfmul.vf v24, v8, f11 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vnclip.wx v20, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v16, v16, zero \n\t" + "vse8.v v16, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bge %[K], t2, LOOP_K%= \n\t" + "TAIL%=: \n\t" + "blez %[K], END%= \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vxor.vv v8, v8, v8 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v0, (a1) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t0, %[K], e32, m8 \n\t" + "vle32.v v8, (a2) \n\t" + "sub %[K], %[K], t0 \n\t" + "vsetvli t1, zero, e32, m8 \n\t" + "jal x0, QUANT%= \n\t" + "END%=: \n\t" + + : [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [SRC] "r"(SRC) + : "cc", "t2", "t1", "t0", "a1", "a2", "f10", "f11"); + } else { + float buffer[8] = { 0.0f }; + size_t cnt = BlkLen / 256; + + __asm__ volatile( + "slli t3, %[BLK], 2 \n\t" + "blt %[K], %[BLK], LOOP_TAIL%= \n\t" + "LOOP_MAIN%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_CMP%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vfabs.v v0, v0 \n\t" + "vfabs.v v8, v8 \n\t" + "vfabs.v v16, v16 \n\t" + "vfabs.v v24, v24 \n\t" + "vfmax.vv v8, v0, v8 \n\t" + "vfmax.vv v16, v16, v24 \n\t" + "vfmax.vv v0, v0, v16 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, LOOP_CMP%= \n\t" + "sub %[SRC], %[SRC], t3 \n\t" + "addi t6, %[CNT], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[CNT], 0 \n\t" + "LOOP_QUANT%=: \n\t" + "addi t6, t6, -1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v8, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v16, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vle32.v v24, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfmul.vf v8, v8, f11 \n\t" + "vfmul.vf v16, v16, f11 \n\t" + "vfmul.vf v24, v24, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vfcvt.x.f.v v8, v8 \n\t" + "vfcvt.x.f.v v16, v16 \n\t" + "vfcvt.x.f.v v24, v24 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vnclip.wx v8, v16, zero \n\t" + "vnclip.wx v12, v24, zero \n\t" + "vsetvli t0, zero, e8, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vnclip.wx v4, v8, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "vse8.v v4, (%[DST]) \n\t" + "addi %[DST], %[DST], 128 \n\t" + "bnez t6, LOOP_QUANT%= \n\t" + "sub %[K], %[K], %[BLK] \n\t" + "bge %[K], %[BLK], LOOP_MAIN%= \n\t" + "blez %[K], END%= \n\t" + "LOOP_TAIL%=: \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vxor.vv v31, v31, v31 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "addi t6, %[K], 0 \n\t" + "addi s1, %[SRC], 0 \n\t" + "TAIL_CMP%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t0, t6, e32, m8 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi %[SRC], %[SRC], 256 \n\t" + "sub t6, t6, t0 \n\t" + "vfabs.v v0, v0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vfmax.vv v0, v0, v4 \n\t" + "vsetvli t0, zero, e32, m2 \n\t" + "vfmax.vv v0, v0, v2 \n\t" + "vsetvli t0, zero, e32, m1 \n\t" + "vfmax.vv v0, v0, v1 \n\t" + "vle32.v v30, (%[BUFFER]) \n\t" + "vfmax.vv v31, v30, v0 \n\t" + "vse32.v v31, (%[BUFFER]) \n\t" + "bnez t6, TAIL_CMP%= \n\t" + "addi t6, %[K], 0 \n\t" + "flw f0, (%[BUFFER]) \n\t" + "flw f1, 4(%[BUFFER]) \n\t" + "flw f2, 8(%[BUFFER]) \n\t" + "flw f3, 12(%[BUFFER]) \n\t" + "flw f4, 16(%[BUFFER]) \n\t" + "flw f5, 20(%[BUFFER]) \n\t" + "flw f6, 24(%[BUFFER]) \n\t" + "flw f7, 28(%[BUFFER]) \n\t" + "fmax.s f1, f0, f1 \n\t" + "fmax.s f3, f2, f3 \n\t" + "fmax.s f5, f4, f5 \n\t" + "fmax.s f7, f6, f7 \n\t" + "fmax.s f3, f1, f3 \n\t" + "fmax.s f7, f5, f7 \n\t" + "fmax.s f10, f3, f7 \n\t" + "fmul.s f10, f10, %[RMAXREC] \n\t" + "fsw f10, (%[DST]) \n\t" + "addi %[DST], %[DST], 4 \n\t" + "fdiv.s f11, %[FONE], f10 \n\t" + "addi t6, %[K], 0 \n\t" + "TAIL_QUANT%=: \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v0, v0, v0 \n\t" + "vsetvli t1, t6, e32, m8 \n\t" + "vle32.v v0, (s1) \n\t" + "addi s1, s1, 256 \n\t" + "sub t6, t6, t1 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vfmul.vf v0, v0, f11 \n\t" + "vfcvt.x.f.v v0, v0 \n\t" + "vsetvli t0, zero, e16, m4 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vsetvli t0, t1, e8, m2 \n\t" + "vnclip.wx v0, v0, zero \n\t" + "vse8.v v0, (%[DST]) \n\t" + "addi %[DST], %[DST], 64 \n\t" + "bnez t6, TAIL_QUANT%= \n\t" + "END%=: \n\t" + : [SRC] "+r"(SRC), [DST] "+r"(DST), [K] "+r"(CountK) + : [FONE] "f"(fone), [RMAXREC] "f"(range_max_reciprocal), [BLK] "r"(BlkLen), [BUFFER] "r"(buffer), + [CNT] "r"(cnt) + : "cc", "t1", "t0", "t6", "s1", "f0", "f1", "f2", "f3", "f4", "f5", "f6"); + } +} + +} // namespace ime1 + +namespace { +#define SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 \ + "vmadot v16, v14, v0 \n\t" \ + "vmadot v18, v14, v1 \n\t" \ + "vmadot v20, v14, v2 \n\t" \ + "vmadot v22, v14, v3 \n\t" \ + "vmadot v16, v15, v4 \n\t" \ + "vmadot v18, v15, v5 \n\t" \ + "vmadot v20, v15, v6 \n\t" \ + "vmadot v22, v15, v7 \n\t" + +#define SQ4BIT_KERNEL_ACC_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 16 \n\t" \ + "addi s3, s1, 32 \n\t" \ + "addi s4, s1, 48 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_ACC_F16_1X4X4 \ + "vfcvt.f.x.v v16, v16 \n\t" \ + "vfcvt.f.x.v v18, v18 \n\t" \ + "vfcvt.f.x.v v20, v20 \n\t" \ + "vfcvt.f.x.v v22, v22 \n\t" \ + "addi s2, s1, 8 \n\t" \ + "addi s3, s1, 16 \n\t" \ + "addi s4, s1, 24 \n\t" \ + "addi s6, s5, 12 \n\t" \ + "vfmacc.vv v28, v16, v24 \n\t" \ + "vfmacc.vv v29, v18, v25 \n\t" \ + "vfmacc.vv v30, v20, v26 \n\t" \ + "vfmacc.vv v31, v22, v27 \n\t" + +#define SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 \ + "vle8.v v4, (s1) \n\t" \ + "addi s1, s1, 128 \n\t" \ + "vle8.v v5, (s2) \n\t" \ + "addi s2, s2, 128 \n\t" \ + "vle8.v v6, (s3) \n\t" \ + "addi s3, s3, 128 \n\t" \ + "vle8.v v7, (s4) \n\t" \ + "addi s4, s4, 128 \n\t" \ + "vsetvli t0, zero, e8, mf4 \n\t" \ + "vle8.v v14, (s5) \n\t" \ + "addi s5, s5, 16 \n\t" \ + "vle8.v v15, (s6) \n\t" \ + "addi s6, s6, 16 \n\t" \ + "addi t5, t5, -1 \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vand.vi v0, v4, 15 \n\t" \ + "vand.vi v1, v5, 15 \n\t" \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vsrl.vi v4, v4, 4 \n\t" \ + "vsrl.vi v5, v5, 4 \n\t" \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v1, (s7) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v8, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v9, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v10, v1, v13 \n\t" \ + "vadd.vi v13, v13, 4 \n\t" \ + "vrgather.vv v11, v1, v13 \n\t" \ + "vadd.vi v13, v13, -12 \n\t" + +// using for M4Kernel +#define LOAD_B_16x8x2 \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vle8.v v6, (s1) \n\t" \ + "addi s1, s1, 32*4 \n\t" \ + "vle8.v v7, (s2) \n\t" \ + "addi s2, s2, 32*4 \n\t" \ + "vle8.v v8, (s3) \n\t" \ + "addi s3, s3, 32*4 \n\t" \ + "vle8.v v9, (s4) \n\t" \ + "addi s4, s4, 32*4 \n\t" \ + \ + "vand.vi v2, v6, 15 \n\t" \ + "vand.vi v3, v7, 15 \n\t" \ + "vand.vi v4, v8, 15 \n\t" \ + "vand.vi v5, v9, 15 \n\t" \ + \ + "vsrl.vi v6, v6, 4 \n\t" \ + "vsrl.vi v7, v7, 4 \n\t" \ + "vsrl.vi v8, v8, 4 \n\t" \ + "vsrl.vi v9, v9, 4 \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16_FP16 \ + "addi s2, s5, -8 \n\t" \ + "addi s3, s5, 8 \n\t" \ + "addi s4, s5, 16 \n\t" \ + "addi s6, s5, 24 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e16, mf4 \n\t" \ + "vle16.v v9, (s5) \n\t" \ + "vle16.v v11, (s3) \n\t" \ + "vle16.v v13, (s4) \n\t" \ + "vle16.v v15, (s6) \n\t" \ + "vsetvli t0, zero, e16, mf2 \n\t" \ + "vle16.v v9, (s2), v0.t \n\t" \ + "vle16.v v11, (s5), v0.t \n\t" \ + "vle16.v v13, (s3), v0.t \n\t" \ + "vle16.v v15, (s4), v0.t \n\t" \ + "vfwcvt.f.f.v v8, v9 \n\t" \ + "vfwcvt.f.f.v v10, v11 \n\t" \ + "vfwcvt.f.f.v v12, v13 \n\t" \ + "vfwcvt.f.f.v v14, v15 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +// [s2|s5, s3, s4, s6] +#define LOAD_SCALE_4x16 \ + "addi s2, s5, -16 \n\t" \ + "addi s3, s5, 16 \n\t" \ + "addi s4, s5, 32 \n\t" \ + "addi s6, s5, 48 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vle32.v v8, (s5) \n\t" \ + "vle32.v v10, (s3) \n\t" \ + "vle32.v v12, (s4) \n\t" \ + "vle32.v v14, (s6) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v8, (s2), v0.t \n\t" \ + "vle32.v v10, (s5), v0.t \n\t" \ + "vle32.v v12, (s3), v0.t \n\t" \ + "vle32.v v14, (s4), v0.t \n\t" \ + "vmv.v.v v9, v8 \n\t" \ + "vmv.v.v v11, v10 \n\t" \ + "vmv.v.v v13, v12 \n\t" \ + "vmv.v.v v15, v14 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "vfmul.vf v8, v8, f1 \n\t" \ + "vfmul.vf v10, v10, f1 \n\t" \ + "vfmul.vf v12, v12, f1 \n\t" \ + "vfmul.vf v14, v14, f1 \n\t" \ + "vfmul.vf v9, v9, f3 \n\t" \ + "vfmul.vf v11, v11, f3 \n\t" \ + "vfmul.vf v13, v13, f3 \n\t" \ + "vfmul.vf v15, v15, f3 \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vfmul.vf v8, v8, f2, v0.t \n\t" \ + "vfmul.vf v10, v10, f2, v0.t \n\t" \ + "vfmul.vf v12, v12, f2, v0.t \n\t" \ + "vfmul.vf v14, v14, f2, v0.t \n\t" \ + "vfmul.vf v9, v9, f4, v0.t \n\t" \ + "vfmul.vf v11, v11, f4, v0.t \n\t" \ + "vfmul.vf v13, v13, f4, v0.t \n\t" \ + "vfmul.vf v15, v15, f4, v0.t \n\t" + +//[s1| BIAS, s2, s3, s4] +#define LOAD_BIAS \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "addi s1, %[BIAS], -16 \n\t" \ + "addi s2, %[BIAS], 16 \n\t" \ + "addi s3, %[BIAS], 32 \n\t" \ + "addi s4, %[BIAS], 48 \n\t" \ + \ + "vle32.v v24, (%[BIAS]) \n\t" \ + "vle32.v v26, (s2) \n\t" \ + "vle32.v v28, (s3) \n\t" \ + "vle32.v v30, (s4) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + "vle32.v v24, (s1), v0.t \n\t" \ + "vle32.v v26, (%[BIAS]), v0.t \n\t" \ + "vle32.v v28, (s2), v0.t \n\t" \ + "vle32.v v30, (s3), v0.t \n\t" \ + "vmv.v.v v25, v24 \n\t" \ + "vmv.v.v v27, v26 \n\t" \ + "vmv.v.v v29, v28 \n\t" \ + "vmv.v.v v31, v30 \n\t" + +#define SQ4BIT_KERNEL_COMP_4x16x16 \ + "vmadot v16, v10, v2 \n\t" \ + "vmadot v18, v10, v3 \n\t" \ + "vmadot v20, v10, v4 \n\t" \ + "vmadot v22, v10, v5 \n\t" \ + "vmadot v16, v11, v6 \n\t" \ + "vmadot v18, v11, v7 \n\t" \ + "vmadot v20, v11, v8 \n\t" \ + "vmadot v22, v11, v9 \n\t" + +#define SAVE_RESULT_4x16 \ + "addi a1, %[C], 0 \n\t" \ + "add a2, %[C], %[LDC] \n\t" \ + "add a3, a2, %[LDC] \n\t" \ + "add a4, a3, %[LDC] \n\t" \ + "addi a2, a2, -16 \n\t" \ + "addi a4, a4, -16 \n\t" \ + "li t1, 0xf0 \n\t" \ + "vmv.s.x v0, t1 \n\t" \ + "vsetvli t0, zero, e32, mf2 \n\t" \ + \ + "vse32.v v24, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v25, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v26, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v27, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v28, (a1) \n\t" \ + "addi a1, a1, 16 \n\t" \ + "vse32.v v29, (a3) \n\t" \ + "addi a3, a3, 16 \n\t" \ + \ + "vse32.v v30, (a1) \n\t" \ + "vse32.v v31, (a3) \n\t" \ + "vsetvli t0, zero, e32, m1 \n\t" \ + \ + "vse32.v v24, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v25, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v26, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v27, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v28, (a2), v0.t \n\t" \ + "addi a2, a2, 16 \n\t" \ + "vse32.v v29, (a4), v0.t \n\t" \ + "addi a4, a4, 16 \n\t" \ + \ + "vse32.v v30, (a2), v0.t \n\t" \ + "vse32.v v31, (a4), v0.t \n\t" + +#define SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 \ + "vsetvli t0, zero, e8, mf2 \n\t" \ + "vle8.v v11, (s6) \n\t" \ + "vsetvli t0, zero, e8, m1 \n\t" \ + "vrgather.vv v12, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v13, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v14, v11, v1 \n\t" \ + "vadd.vi v1, v1, 4 \n\t" \ + "vrgather.vv v15, v11, v1 \n\t" \ + "vadd.vi v1, v1, -12 \n\t" + +template +void SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 32 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 32 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16_FP16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float * CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} + +template +void SQ4BitGemmM4Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias, + const size_t ldc) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t LDC = ldc * sizeof(float); + const size_t INNER = BlkLen / 16; + float tmp[4 * 16]; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "li s1, 24 \n\t" + "vmv.v.i v1, 3 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v1, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v1, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v1, 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + // scale offset + "addi s5, s1, 0 \n\t" + // zp offset + "addi s6, s1, 64 \n\t" + "addi s1, s6, 16 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1_v2 + + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vsub.vv v2, v2, v12 \n\t" + "vsub.vv v6, v6, v12 \n\t" + "vsub.vv v3, v3, v13 \n\t" + "vsub.vv v7, v7, v13 \n\t" + "vsub.vv v4, v4, v14 \n\t" + "vsub.vv v8, v8, v14 \n\t" + "vsub.vv v5, v5, v15 \n\t" + "vsub.vv v9, v9, v15 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t NBLKS = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + if (NBLKS < 16) { + CPtr = tmp; + LDC = 16 * sizeof(float); + } + if (Bias != nullptr) { + const float * bias = Bias + n; + if (NBLKS < 16) { + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "vse32.v v0, (%[DST]) \n\t" + : + : [SRC] "r"(bias), [DST] "r"(tmp), [N] "r"(NBLKS) + : "cc", "t0"); + bias = tmp; + } + __asm__ volatile(LOAD_BIAS + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr), [BIAS] "r"(bias) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", + "s2", "s3", "s4", "s5", "s6"); + + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v24, v24, v24 \n\t" + "addi t3, %[BlockCountK], 0 \n\t" + "addi a1, %[A], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "BLOCK_COUNTK_LOOP%=: \n\t" + "addi s5, s1, 0 \n\t" + "addi s1, s5, 64 \n\t" + "addi s2, s1, 32 \n\t" + "addi s3, s1, 32*2 \n\t" + "addi s4, s1, 32*3 \n\t" + "vsetvli t0, zero, e32, m8 \n\t" + "vxor.vv v16, v16, v16 \n\t" + // load a scale + "flw f1, (a1) \n\t" + "flw f2, 4(a1) \n\t" + "flw f3, 8(a1) \n\t" + "flw f4, 12(a1) \n\t" + "addi a1, a1, 16 \n\t" + "addi t2, %[INNER], 0 \n\t" + "BLOCK_INNER_LOOP%=: \n\t" + + LOAD_B_16x8x2 + + "vsetvli t0, zero, e8, m1 \n\t" + "vle8.v v10, (a1) \n\t" + + "addi a1, a1, 32 \n\t" + "vle8.v v11, (a1) \n\t" + "addi a1, a1, 32 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + "vadd.vi v8, v8, -8 \n\t" + "vadd.vi v9, v9, -8 \n\t" + + SQ4BIT_KERNEL_COMP_4x16x16 + + "addi t2, t2, -1 \n\t" + "bnez t2, BLOCK_INNER_LOOP%= \n\t" + + LOAD_SCALE_4x16 + + "vsetvli t0, zero, e32, m8 \n\t" + "vfcvt.f.x.v v16, v16 \n\t" + "vfmacc.vv v24, v16, v8 \n\t" + "addi t3, t3, -1 \n\t" + "bnez t3, BLOCK_COUNTK_LOOP%= \n\t" + + "RESULT_SAVE%=: \n\t" + + SAVE_RESULT_4x16 + + : + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [LDC] "r"(LDC), + [BlockCountK] "r"(BlockCountK), [C] "r"(CPtr) + : "cc", "t0", "t1", "t2", "t3", "a1", "a2", "a3", "a4", "f1", "f2", "f3", "f4", "s1", "s2", "s3", + "s4", "s5", "s6"); + } + } + } + if (CountN % 16 != 0) { + // stroe output from tmp to C when NBLKS less than 16. + float * CPtr = C + CountN / 16 * 16; + const size_t N = CountN % 16; + LDC = ldc * sizeof(float); + __asm__ volatile( + "vsetvli t0, %[N], e32, m2 \n\t" + "vle32.v v0, (%[SRC]) \n\t" + "addi s2, %[SRC], 64 \n\t" + "addi s3, %[SRC], 64*2 \n\t" + "addi s4, %[SRC], 64*3 \n\t" + "vle32.v v2, (s2) \n\t" + "vle32.v v4, (s3) \n\t" + "vle32.v v6, (s4) \n\t" + "add t2, %[DST], %[LDC] \n\t" + "add t3, t2, %[LDC] \n\t" + "add t4, t3, %[LDC] \n\t" + "vse32.v v0, (%[DST]) \n\t" + "vse32.v v2, (t2) \n\t" + "vse32.v v4, (t3) \n\t" + "vse32.v v6, (t4) \n\t" + : + : [N] "r"(N), [SRC] "r"(tmp), [DST] "r"(CPtr), [LDC] "r"(LDC) + : "cc", "t0", "t2", "t3", "t4", "s2", "s3", "s4"); + } +} + +template +void SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + size_t INNER = BlkLen / 16; + + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + // zp offset + "addi s7, %[B], 32 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s7, %[B], 32 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 48 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 72 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 120 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + "addi s7, s1, 32 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(_Float16); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + + "vsetvli t0, zero, e32, mf2 \n\t" + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 8 \n\t" + "addi s3, %[B], 16 \n\t" + "addi s4, %[B], 24 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "LOOP_K%=: \n\t" + "vsetvli t0, zero, e16, mf4 \n\t" + "vle16.v v4, (s1) \n\t" + "addi s1, s1, 32 \n\t" + "vle16.v v5, (s2) \n\t" + "addi s2, s2, 56 \n\t" + "vle16.v v6, (s3) \n\t" + "addi s3, s3, 80 \n\t" + "vle16.v v7, (s4) \n\t" + "addi s4, s4, 104 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "vfwcvt.f.f.v v8, v4 \n\t" + "vfwcvt.f.f.v v9, v5 \n\t" + "vfwcvt.f.f.v v10, v6 \n\t" + "vfwcvt.f.f.v v11, v7 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_F16_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +void SQ4BitGemmM1Kernel_CompInt8_Impl(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountN, + size_t BlockCountK, + const float * Bias) { + GGML_UNUSED(QuantBScale); + GGML_UNUSED(QuantBZeroPoint); + const size_t INNER = BlkLen / 16; + if constexpr (HasZeroPoint) { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(uint8_t) + // zp + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + // scale offset, scale0.0, scale1.0, scale2.0, scale3.0....scale15.0 + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + // zp offset + "addi s7, %[B], 64 \n\t" + // a offset + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + + // load scale + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + // load a scale + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + // a scale * b scale + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + + "vsetvli t0, zero, e8, m1 \n\t" + "vmv.v.i v13, 3 \n\t" + "li s1, 24 \n\t" + "vsetvli t0, s1, e8, m1 \n\t" + "vmv.v.i v13, 2 \n\t" + "vsetvli t0, zero, e8, mf2 \n\t" + "vmv.v.i v13, 1 \n\t" + "vsetvli t0, zero, e8, mf4 \n\t" + "vmv.v.i v13, 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s7, %[B], 64 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 80 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 96 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 112 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 128 \n\t" + + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + + SQ4BIT_KERNEL_LOAD_ZP_16X1 + + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vsub.vv v0, v0, v8 \n\t" + "vsub.vv v4, v4, v8 \n\t" + "vsub.vv v1, v1, v9 \n\t" + "vsub.vv v5, v5, v9 \n\t" + "vsub.vv v2, v2, v10 \n\t" + "vsub.vv v6, v6, v10 \n\t" + "vsub.vv v3, v3, v11 \n\t" + "vsub.vv v7, v7, v11 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + "addi s7, s1, 64 \n\t" + + "bnez %[CNT], LOOP_K%= \n\t" + + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6", "s7"); + } + } + } else { + for (size_t n = 0; n < CountN; n += 16) { + size_t nblks = (CountN - n) > 16 ? 16 : CountN - n; + std::byte * QuantBDataPtr = (std::byte *) QuantBData + // + n * BlockCountK * BlkLen / 2 + // b data + n * BlockCountK * sizeof(float); // scale + float * CPtr = C + n; + size_t cnt = BlockCountK; + if (Bias != nullptr) { + const float * bias = Bias + n; + __asm__ volatile( + "addi t3, %[NBLKS], 0 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v28, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v29, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v30, (%[BIAS]) \n\t" + "sub t3, t3, t0 \n\t" + "addi %[BIAS], %[BIAS], 16 \n\t" + "vsetvli t0, t3, e32, mf2 \n\t" + "vle32.v v31, (%[BIAS]) \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks), [BIAS] "+r"(bias) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } else { + __asm__ volatile( + "vsetvli t0, zero, e32, m4 \n\t" + "vxor.vv v28, v28, v28 \n\t" + "addi s1, %[B], 0 \n\t" + "addi s2, %[B], 16 \n\t" + "addi s3, %[B], 32 \n\t" + "addi s4, %[B], 48 \n\t" + + "addi s5, %[A], 0 \n\t" + "addi s6, %[A], 12 \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + "LOOP_K%=: \n\t" + "vle32.v v8, (s1) \n\t" + "addi s1, s1, 64 \n\t" + "vle32.v v9, (s2) \n\t" + "addi s2, s2, 80 \n\t" + "vle32.v v10, (s3) \n\t" + "addi s3, s3, 96 \n\t" + "vle32.v v11, (s4) \n\t" + "addi s4, s4, 112 \n\t" + "flw f1, (s5) \n\t" + "addi s5, s5, 4 \n\t" + + "addi t5, %[INNER], 0 \n\t" + "vxor.vv v16, v16, v16 \n\t" + "vxor.vv v18, v18, v18 \n\t" + "vxor.vv v20, v20, v20 \n\t" + "vxor.vv v22, v22, v22 \n\t" + "vfmul.vf v24, v8, f1 \n\t" + "vfmul.vf v25, v9, f1 \n\t" + "vfmul.vf v26, v10, f1 \n\t" + "vfmul.vf v27, v11, f1 \n\t" + "addi %[CNT], %[CNT], -1 \n\t" + "vsetvli t0, zero, e8, m1 \n\t" + "LOOP_INNER%=: \n\t" + + SQ4BIT_KERNEL_LOAD_1x8x2_4X8X4 + + "vadd.vi v0, v0, -8 \n\t" + "vadd.vi v1, v1, -8 \n\t" + "vadd.vi v2, v2, -8 \n\t" + "vadd.vi v3, v3, -8 \n\t" + "vadd.vi v4, v4, -8 \n\t" + "vadd.vi v5, v5, -8 \n\t" + "vadd.vi v6, v6, -8 \n\t" + "vadd.vi v7, v7, -8 \n\t" + + SQ4BIT_KERNEL_COMP_1x8x2_4X8X4 + + "bnez t5, LOOP_INNER%= \n\t" + "vsetvli t0, zero, e32, mf2 \n\t" + + SQ4BIT_KERNEL_ACC_1X4X4 + + "bnez %[CNT], LOOP_K%= \n\t" + "addi t3, zero, 16 \n\t" + "addi s1, %[C], 16 \n\t" + "addi s2, %[C], 32 \n\t" + "addi s3, %[C], 48 \n\t" + "blt %[NBLKS], t3, ST_TAIL%= \n\t" + "vse32.v v28, (%[C]) \n\t" + "vse32.v v29, (s1) \n\t" + "vse32.v v30, (s2) \n\t" + "vse32.v v31, (s3) \n\t" + "jal x0, END%= \n\t" + + "ST_TAIL%=: \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v28, (%[C]) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v29, (s1) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v30, (s2) \n\t" + "vsetvli t0, %[NBLKS], e32, mf2 \n\t" + "sub %[NBLKS], %[NBLKS], t0 \n\t" + "vse32.v v31, (s3) \n\t" + "END%=: \n\t" + + : [CNT] "+r"(cnt), [NBLKS] "+r"(nblks) + : [INNER] "r"(INNER), [A] "r"(QuantA), [B] "r"(QuantBDataPtr), [C] "r"(CPtr) + : "cc", "t0", "t5", "t3", "f1", "s1", "s2", "s3", "s4", "s5", "s6"); + } + } + } +} + +template +inline void SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { + if (scalestride == 4) { + SQ4BitGemmM4Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias, ldc); + + } else if (scalestride == 2) { + SQ4BitGemmM4Kernel_CompInt8_ScaleFp16_Impl( + BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias, ldc); + } +} + +template +inline void SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t BlockStrideQuantB, + const float * Bias, + const size_t ldc, + const size_t scalestride) { + if (scalestride == 4) { + SQ4BitGemmM1Kernel_CompInt8_Impl(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, C, + CountN, BlockStrideQuantB, Bias); + } else if (scalestride == 2) { + SQ4BitGemmM1Kernel_CompInt8_ScaleFp16_Impl(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountN, BlockStrideQuantB, Bias); + } +} + +} // namespace + +namespace ime1 { +size_t gemm_kernel_i8i4(size_t BlkLen, + const std::byte * QuantA, + const std::byte * QuantBData, + const float * QuantBScale, + const std::byte * QuantBZeroPoint, + float * C, + size_t CountM, + size_t CountN, + size_t CountK, + size_t BlockCountK, + size_t ldc, + const float * Bias, + const size_t ScaleStride) { + GGML_UNUSED(CountM); + GGML_UNUSED(CountK); + GGML_UNUSED(ldc); + if (CountM >= 4) { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM4Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 4; + } else { + if (QuantBZeroPoint != nullptr) { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, QuantBZeroPoint, + C, CountM, CountN, BlockCountK, Bias, ldc, ScaleStride); + } else { + SQ4BitGemmM1Kernel_CompInt8_DispatchOnBlkLen(BlkLen, QuantA, QuantBData, QuantBScale, + QuantBZeroPoint, C, CountM, CountN, BlockCountK, Bias, + ldc, ScaleStride); + } + return 1; + } +} +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime diff --git a/ggml/src/ggml-cpu/spacemit/ime_kernels.h b/ggml/src/ggml-cpu/spacemit/ime_kernels.h new file mode 100644 index 0000000000..7570634150 --- /dev/null +++ b/ggml/src/ggml-cpu/spacemit/ime_kernels.h @@ -0,0 +1,26 @@ +#pragma once + +#include + +namespace sqnbitgemm_spacemit_ime { +namespace ime1 { +size_t gemm_kernel_i8i4(size_t blk_len, + const std::byte * quant_a_ptr, + const std::byte * quant_b_data, + const float * quant_b_scale, + const std::byte * quant_b_zp, + float * c_ptr, + size_t count_m, + size_t count_n, + size_t count_k, + size_t block_count_k, + size_t ldc, + const float * bias, + const size_t scale_stride); + +void quantize_a_row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); + +void quantize_a_4row_i8(size_t blk_len, const float * a_ptr, size_t count_k, std::byte * quant_a_ptr); + +} // namespace ime1 +} // namespace sqnbitgemm_spacemit_ime diff --git a/ggml/src/ggml-cpu/vec.h b/ggml/src/ggml-cpu/vec.h index ef334d089d..341e64e64f 100644 --- a/ggml/src/ggml-cpu/vec.h +++ b/ggml/src/ggml-cpu/vec.h @@ -610,7 +610,7 @@ inline static void ggml_vec_mad1_f32(const int n, float * y, const float * x, co for (int i = 0; i < np; i += GGML_F32_STEP) { for (int j = 0; j < GGML_F32_ARR; j++) { ay[j] = GGML_F32_VEC_LOAD(x + i + j*GGML_F32_EPR); - ay[j] = GGML_F32_VEC_FMA(ay[j], vs, vb); + ay[j] = GGML_F32_VEC_FMA(vb, ay[j], vs); GGML_F32_VEC_STORE(y + i + j*GGML_F32_EPR, ay[j]); } diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index 1b763a6289..746f43966b 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -329,7 +329,11 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg } else #endif // GGML_USE_MUSA && GGML_MUSA_MUDNN_COPY { - CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + if (src0->type == GGML_TYPE_F32) { + ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); + } else { + CUDA_CHECK(cudaMemcpyAsync(src1_ddc, src0_ddc, ggml_nbytes(src0), cudaMemcpyDeviceToDevice, main_stream)); + } } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); @@ -400,7 +404,13 @@ void ggml_cuda_dup(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void* ggml_cuda_cpy_fn(const ggml_tensor * src0, ggml_tensor * src1) { if (src0->type == src1->type && ggml_is_contiguous(src0) && ggml_is_contiguous(src1)) { - return nullptr; + // Prioritize CUDA graph compatibility over direct memory copy optimization. + // Using copy kernels here maintains graph indirection support, preventing performance regression from disabled CUDA graphs. + if (src0->type == GGML_TYPE_F32) { + return (void*) cpy_flt>; + } else { + return nullptr; + } } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32) { return (void*) cpy_flt>; } else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_BF16) { diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 5cd1e0d862..b7e81b21bc 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2641,6 +2641,8 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud const std::string ffn_moe_gate_bias_prefix = "ffn_moe_gate_biased"; const std::string ffn_moe_up_bias_prefix = "ffn_moe_up_biased"; const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; + const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; + const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2669,7 +2671,9 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud (node->src[1] ? node->src[1]->name != gemma3n_per_layer_proj_src1_name : true) && strncmp(node->name, ffn_moe_gate_bias_prefix.c_str(), ffn_moe_gate_bias_prefix.size()) != 0 && strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && - strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0) { + strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && + strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and @@ -3639,9 +3643,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: case GGML_OP_SUM: - case GGML_OP_ARGSORT: case GGML_OP_ACC: return true; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_GROUP_NORM: diff --git a/ggml/src/ggml-metal/ggml-metal-device.cpp b/ggml/src/ggml-metal/ggml-metal-device.cpp index 0bf7fe9f92..819f31c8a3 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -495,22 +495,17 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_ case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 == 4) { + if (ne00 < 32) { nsg = 1; nr0 = 32; - nr1 = 4; - suffix = "_c4"; - } else if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; + suffix = "_short"; } else { - nsg = N_SG_F; - nr0 = N_R0_F; + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } } break; case GGML_TYPE_Q4_0: @@ -727,18 +722,11 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra case GGML_TYPE_F16: case GGML_TYPE_BF16: { - if (ne00 % 4 == 0) { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - suffix = "_4"; - } else { - nsg = N_SG_F; - nr0 = N_R0_F; - nr1 = 1; - smem = 32*sizeof(float)*N_R0_F; - } + nsg = std::min(4, (ne00 + 127) / 128); + nr0 = 2; + nr1 = 1; + smem = 32*sizeof(float)*nr0; + suffix = ne00 % 4 == 0 ? "_4" : ""; } break; case GGML_TYPE_Q4_0: { diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index cced0369d0..523f9d71ba 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -683,9 +683,11 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); case GGML_OP_PAD_REFLECT_1D: case GGML_OP_TIMESTEP_EMBEDDING: - case GGML_OP_ARGSORT: case GGML_OP_LEAKY_RELU: return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + // TODO: Support arbitrary column width + return op->src[0]->ne[0] <= 1024; case GGML_OP_ARANGE: return true; case GGML_OP_FLASH_ATTN_EXT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index d355c6dfc7..88c98423eb 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -8,9 +8,6 @@ // // TODO: for optimal performance, become function of the device and work size -#define N_R0_F 2 -#define N_SG_F 4 - #define N_R0_Q4_0 4 #define N_SG_Q4_0 2 @@ -352,6 +349,7 @@ typedef struct { uint64_t nb13; int32_t ne0; int32_t ne1; + int32_t nr0; int16_t r2; int16_t r3; } ggml_metal_kargs_mul_mv; @@ -427,6 +425,7 @@ typedef struct { int32_t ne0; int32_t ne1; uint64_t nb1; + int32_t nr0; } ggml_metal_kargs_mul_mv_id; // NORM diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index d7267a6aed..e85a223c01 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { } else { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op); + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { /*.nb13 =*/ nb13, /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, + /*.nr0 =*/ nr0, /*.r2 =*/ r2, /*.r3 =*/ r3, }; - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); @@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1); } } else { + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); + + const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); + const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); + const int nsg = ggml_metal_pipeline_get_nsg(pipeline); + + const size_t smem = ggml_metal_pipeline_get_smem(pipeline); + ggml_metal_kargs_mul_mv_id args = { /*.nei0 =*/ ne20, /*.nei1 =*/ ne21, @@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) { /*.ne0 =*/ ne0, /*.ne1 =*/ ne1, /*.nb1 =*/ nb1, + /*.nr0 =*/ nr0, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op); - - const int nr0 = ggml_metal_pipeline_get_nr0(pipeline); - const int nr1 = ggml_metal_pipeline_get_nr1(pipeline); - const int nsg = ggml_metal_pipeline_get_nsg(pipeline); - - const size_t smem = ggml_metal_pipeline_get_smem(pipeline); - if (ggml_is_quantized(op->src[0]->type)) { GGML_ASSERT(ne00 >= nsg*nr0); } diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 0271fd5b25..96df6f0ce6 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3531,7 +3531,25 @@ void kernel_mul_mv_t_t_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + } +} + +template kernel void kernel_mul_mv_t_t( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3541,17 +3559,17 @@ kernel void kernel_mul_mv_t_t( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; +typedef decltype(kernel_mul_mv_t_t) mul_mv_t_t; -template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f32_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_f16_f16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; -template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_f32")]] kernel mul_mv_t_t kernel_mul_mv_t_t; +template [[host_name("kernel_mul_mv_bf16_bf16")]] kernel mul_mv_t_t kernel_mul_mv_t_t; #endif template @@ -3637,7 +3655,25 @@ void kernel_mul_mv_t_t_4_impl( helper_mv_reduce_and_write(dst_f32, sumf, r0, args.ne01, tiisg, sgitg, shmem); } -template +template +void kernel_mul_mv_t_t_4_disp( + args_t args, + device const char * src0, + device const char * src1, + device char * dst, + threadgroup char * shmem, + uint3 tgpig, + ushort tiisg, + ushort sgitg) { + switch (args.nr0) { + //case 1: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + case 2: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 3: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + //case 4: kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); break; + }; +} + +template kernel void kernel_mul_mv_t_t_4( constant ggml_metal_kargs_mul_mv & args, device const char * src0, @@ -3647,23 +3683,21 @@ kernel void kernel_mul_mv_t_t_4( uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { - kernel_mul_mv_t_t_4_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + kernel_mul_mv_t_t_4_disp(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; +typedef decltype(kernel_mul_mv_t_t_4) mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f32_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_f16_f16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; -template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_f32_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; +template [[host_name("kernel_mul_mv_bf16_bf16_4")]] kernel mul_mv_t_t_4 kernel_mul_mv_t_t_4; #endif -#define N_MV_T_T 4 - -template -void kernel_mul_mv_c4_impl( +template +void kernel_mul_mv_t_t_short_impl( args_t args, device const char * src0, device const char * src1, @@ -3671,7 +3705,7 @@ void kernel_mul_mv_c4_impl( uint3 tgpig, ushort tiisg) { const int r0 = tgpig.x*32 + tiisg; - const int rb = tgpig.y*N_MV_T_T; + const int r1 = tgpig.y; const int im = tgpig.z; if (r0 >= args.ne01) { @@ -3683,33 +3717,32 @@ void kernel_mul_mv_c4_impl( const uint64_t offset0 = r0*args.nb01 + (i12/args.r2)*args.nb02 + (i13/args.r3)*args.nb03; - device const T04 * x = (device const T04 *) (src0 + offset0); + device const T0 * x = (device const T0 *) (src0 + offset0); device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1; - for (int row = 0; row < N_MV_T_T; ++row) { - int r1 = rb + row; - if (r1 >= args.ne11) { - break; - } + const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; - const uint64_t offset1 = r1*args.nb11 + (i12 )*args.nb12 + (i13 )*args.nb13; + device const T1 * y = (device const T1 *) (src1 + offset1); - device const T14 * y = (device const T14 *) (src1 + offset1); + float res = 0.0f; - dst_f32[(uint64_t)r1*args.ne0 + r0] = dot((float4) x[0], (float4) y[0]); + for (int i = 0; i < args.ne00; ++i) { + res += (float) x[i] * (float) y[i]; } + + dst_f32[(uint64_t)r1*args.ne0 + r0] = res; } -template -kernel void kernel_mul_mv_c4( +template +kernel void kernel_mul_mv_t_t_short( constant ggml_metal_kargs_mul_mv & args, device const char * src0, device const char * src1, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]]) { - kernel_mul_mv_c4_impl( + kernel_mul_mv_t_t_short_impl( args, src0, src1, @@ -3718,14 +3751,14 @@ kernel void kernel_mul_mv_c4( tiisg); } -typedef decltype(kernel_mul_mv_c4) mul_mv_c4_t; +typedef decltype(kernel_mul_mv_t_t_short) mul_mv_t_t_short_t; -template [[host_name("kernel_mul_mv_f32_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_f16_f16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_f32_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_f16_f16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_bf16_f32_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; -template [[host_name("kernel_mul_mv_bf16_bf16_c4")]] kernel mul_mv_c4_t kernel_mul_mv_c4; +template [[host_name("kernel_mul_mv_bf16_f32_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; +template [[host_name("kernel_mul_mv_bf16_bf16_short")]] kernel mul_mv_t_t_short_t kernel_mul_mv_t_t_short; #endif static float rope_yarn_ramp(const float low, const float high, const int i0) { @@ -8458,7 +8491,7 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f16")]] kernel mul_mm_id kernel_m // matrix-vector multiplication // -typedef void (kernel_mul_mv_impl_t)( +typedef void (kernel_mul_mv_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8466,7 +8499,7 @@ typedef void (kernel_mul_mv_impl_t)( uint3 tgpig, ushort tiisg); -typedef void (kernel_mul_mv2_impl_t)( +typedef void (kernel_mul_mv2_disp_t)( ggml_metal_kargs_mul_mv args, device const char * src0, device const char * src1, @@ -8476,7 +8509,7 @@ typedef void (kernel_mul_mv2_impl_t)( ushort tiisg, ushort sgitg); -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8487,10 +8520,10 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, tgpig, tiisg); + disp_fn(args, src0, src1, dst, tgpig, tiisg); } -template +template void mmv_fn( ggml_metal_kargs_mul_mv args, device const char * src0, @@ -8501,12 +8534,12 @@ void mmv_fn( ushort tiitg, ushort tiisg, ushort sgitg) { - impl_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); + disp_fn(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -typedef decltype(mmv_fn>) mul_mv_impl_fn_t; +typedef decltype(mmv_fn>) mul_mv_disp_fn_t; -template +template kernel void kernel_mul_mv_id( constant ggml_metal_kargs_mul_mv_id & args, device const char * src0s, @@ -8553,11 +8586,12 @@ kernel void kernel_mul_mv_id( /*.nb13 =*/ args.nb12, // ne12 == 1 /*.ne0 =*/ args.ne0, /*.ne1 =*/ 1, // args.ne1, + /*.nr0 =*/ args.nr0, /*.r2 =*/ 1, /*.r3 =*/ 1, }; - impl_fn( + disp_fn( args0, /* src0 */ src0_cur, /* src1 */ src1_cur, @@ -8569,19 +8603,19 @@ kernel void kernel_mul_mv_id( sgitg); } -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_t; -typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; +typedef decltype(kernel_mul_mv_id>>) kernel_mul_mv_id_4_t; -template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; #endif -template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; -template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f32_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_f16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; +template [[host_name("kernel_mul_mv_id_bf16_f32_4")]] kernel kernel_mul_mv_id_4_t kernel_mul_mv_id>>; #endif template [[host_name("kernel_mul_mv_id_q8_0_f32")]] kernel kernel_mul_mv_id_t kernel_mul_mv_id>>; diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 0cf3b92464..79d2148744 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ -2889,10 +2889,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te case GGML_OP_REPEAT: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded case GGML_OP_PAD: - return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && - op->src[0]->ne[3] == 1 && op->ne[3] == 1 && - (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) && - (ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0); + return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_UPSCALE: return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_CONV_2D: @@ -4222,15 +4219,19 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c GGML_ASSERT(dst); GGML_ASSERT(dst->extra); - const int ne00 = src0 ? src0->ne[0] : 0; - const cl_ulong nb01 = src0 ? src0->nb[1] : 0; - const cl_ulong nb02 = src0 ? src0->nb[2] : 0; - const int ne10 = src1 ? src1->ne[0] : 0; - const cl_ulong nb10 = src1 ? src1->nb[0] : 0; - const int ne11 = src1 ? src1->ne[1] : 0; - const cl_ulong nb11 = src1 ? src1->nb[1] : 0; - const cl_ulong nb1 = dst ? dst->nb[1] : 0; - const cl_ulong nb2 = dst ? dst->nb[2] : 0; + const int ne00 = src0->ne[0]; + const cl_ulong nb01 = src0->nb[1]; + const cl_ulong nb02 = src0->nb[2]; + const cl_ulong nb03 = src0->nb[3]; + const int ne10 = src1->ne[0]; + const cl_ulong nb10 = src1->nb[0]; + const int ne11 = src1->ne[1]; + const int ne12 = src1->ne[2]; + const cl_ulong nb11 = src1->nb[1]; + const cl_ulong nb12 = src1->nb[2]; + const cl_ulong nb1 = dst->nb[1]; + const cl_ulong nb2 = dst->nb[2]; + const cl_ulong nb3 = dst->nb[3]; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -4267,14 +4268,17 @@ static void ggml_cl_get_rows(ggml_backend_t backend, const ggml_tensor * src0, c CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01)); CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); - CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb10)); - CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb11)); - CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1)); - CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb10)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb1)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(cl_ulong), &nb2)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &nb3)); - size_t global_work_size[] = {(size_t)ne10, (size_t)ne11, 1}; - size_t local_work_size[] = {1, 1, 1}; + size_t global_work_size[] = {(size_t)ne10*64, (size_t)ne11, (size_t)ne12}; + size_t local_work_size[] = {64, 1, 1}; backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } @@ -5874,7 +5878,6 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t GGML_ASSERT(dst->extra); GGML_ASSERT(src0->type == GGML_TYPE_F32); GGML_ASSERT(dst->type == GGML_TYPE_F32); - GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; @@ -5892,28 +5895,67 @@ static void ggml_cl_pad(ggml_backend_t backend, const ggml_tensor * src0, ggml_t const int s_ne0 = src0->ne[0]; const int s_ne1 = src0->ne[1]; const int s_ne2 = src0->ne[2]; + const int s_ne3 = src0->ne[3]; + + const int s_nb0 = src0->nb[0]; + const int s_nb1 = src0->nb[1]; + const int s_nb2 = src0->nb[2]; + const int s_nb3 = src0->nb[3]; const int d_ne0 = dst->ne[0]; const int d_ne1 = dst->ne[1]; const int d_ne2 = dst->ne[2]; + const int d_ne3 = dst->ne[3]; + + const int d_nb0 = dst->nb[0]; + const int d_nb1 = dst->nb[1]; + const int d_nb2 = dst->nb[2]; + const int d_nb3 = dst->nb[3]; + + const int lp0 = ((const int*)(dst->op_params))[0]; + const int rp0 = ((const int*)(dst->op_params))[1]; + const int lp1 = ((const int*)(dst->op_params))[2]; + const int rp1 = ((const int*)(dst->op_params))[3]; + const int lp2 = ((const int*)(dst->op_params))[4]; + const int rp2 = ((const int*)(dst->op_params))[5]; + const int lp3 = ((const int*)(dst->op_params))[6]; + const int rp3 = ((const int*)(dst->op_params))[7]; cl_kernel kernel = backend_ctx->kernel_pad; - CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); - CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); - CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); - CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); - CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); - CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); - CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); - CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &d_ne0)); - CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &d_ne1)); - CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra_src0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &off_src0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra_dst->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &off_dst)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &s_ne0)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &s_ne1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &s_ne2)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &s_ne3)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &s_nb0)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &s_nb1)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &s_nb2)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &s_nb3)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &d_ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &d_ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &d_ne2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &d_ne3)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(cl_ulong), &d_nb0)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(cl_ulong), &d_nb1)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(cl_ulong), &d_nb2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(cl_ulong), &d_nb3)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &lp0)); + CL_CHECK(clSetKernelArg(kernel, 21, sizeof(int), &rp0)); + CL_CHECK(clSetKernelArg(kernel, 22, sizeof(int), &lp1)); + CL_CHECK(clSetKernelArg(kernel, 23, sizeof(int), &rp1)); + CL_CHECK(clSetKernelArg(kernel, 24, sizeof(int), &lp2)); + CL_CHECK(clSetKernelArg(kernel, 25, sizeof(int), &rp2)); + CL_CHECK(clSetKernelArg(kernel, 26, sizeof(int), &lp3)); + CL_CHECK(clSetKernelArg(kernel, 27, sizeof(int), &rp3)); size_t lws0 = 64; size_t gws0 = (( (size_t)d_ne0 + lws0 - 1 ) / lws0) * lws0; - size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2 }; + size_t global_work_size[] = { gws0, (size_t)d_ne1, (size_t)d_ne2*d_ne3 }; size_t local_work_size[] = { lws0, 1, 1 }; size_t * local_work_size_ptr = local_work_size; diff --git a/ggml/src/ggml-opencl/kernels/get_rows.cl b/ggml/src/ggml-opencl/kernels/get_rows.cl index b3fea2923d..c2962edc98 100644 --- a/ggml/src/ggml-opencl/kernels/get_rows.cl +++ b/ggml/src/ggml-opencl/kernels/get_rows.cl @@ -69,11 +69,14 @@ kernel void kernel_get_rows_f32( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -81,14 +84,19 @@ kernel void kernel_get_rows_f32( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global float *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global float *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -102,11 +110,14 @@ kernel void kernel_get_rows_f16( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -114,14 +125,19 @@ kernel void kernel_get_rows_f16( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00; ind += get_local_size(0)) { - ((global float *) ((global char *) dst + i11*nb2 + i10*nb1))[ind] = - ((global half *) ((global char *) src0 + r*nb01 + i02*nb02))[ind]; + if (ind >= ne00) { + return; + } + ((global float *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1))[ind] = + ((global half *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03))[ind]; } } @@ -135,11 +151,14 @@ kernel void kernel_get_rows_q4_0( int ne00, ulong nb01, ulong nb02, + ulong nb03, int ne10, ulong nb10, ulong nb11, + ulong nb12, ulong nb1, - ulong nb2 + ulong nb2, + ulong nb3 ) { src0 = (global void*)((global char*)src0 + offset0); src1 = (global int*)((global char*)src1 + offset1); @@ -149,15 +168,20 @@ kernel void kernel_get_rows_q4_0( int i10 = get_group_id(0); int i11 = get_group_id(1); + int i12 = get_group_id(2); - int r = ((global int32_t *) ((global char *) src1 + i11*nb11 + i10*nb10))[0]; + int r = ((global int32_t *) ((global char *) src1 + i12*nb12 + i11*nb11 + i10*nb10))[0]; int i02 = i11; + int i03 = i12; for (int ind = get_local_id(0); ind < ne00/16; ind += get_local_size(0)) { float16 temp; + if (ind >= ne00) { + return; + } dequantize_q4_0_f32( - ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02)) + ind/NL, ind%NL, &temp); - *(((global float16 *) ((global char *) dst + i11*nb2 + i10*nb1)) + ind) = temp; + ((global struct block_q4_0 *) ((global char *) src0 + r*nb01 + i02*nb02 + i03*nb03)) + ind/NL, ind%NL, &temp); + *(((global float16 *) ((global char *) dst + i12*nb3 + i11*nb2 + i10*nb1)) + ind) = temp; } } diff --git a/ggml/src/ggml-opencl/kernels/pad.cl b/ggml/src/ggml-opencl/kernels/pad.cl index 747fa7febc..31fb7ccd3b 100644 --- a/ggml/src/ggml-opencl/kernels/pad.cl +++ b/ggml/src/ggml-opencl/kernels/pad.cl @@ -1,30 +1,39 @@ kernel void kernel_pad( - global const void * src0_ptr, - ulong src0_offset, - global void * dst_ptr, - ulong dst_offset, - int s_ne0, int s_ne1, int s_ne2, - int d_ne0, int d_ne1, int d_ne2 + global void * src0, + ulong offset0, + global void * dst, + ulong offsetd, + int ne00, int ne01, int ne02, int ne03, + ulong nb00, ulong nb01, ulong nb02, ulong nb03, + int ne0, int ne1, int ne2, int ne3, + ulong nb0, ulong nb1, ulong nb2, ulong nb3, + int lp0, int rp0, + int lp1, int rp1, + int lp2, int rp2, + int lp3, int rp3 ) { - global const float * src0 = (global const float *)((global const char *)src0_ptr + src0_offset); - global float * dst = (global float *)((global char *)dst_ptr + dst_offset); + src0 = (global float*)((global char*)src0 + offset0); + dst = (global float*)((global char*)dst + offsetd); - int nidx = get_global_id(0); - int idx_d1 = get_group_id(1); - int idx_d2 = get_group_id(2); + int i0 = get_global_id(0); + int i1 = get_group_id(1); + int i2 = get_group_id(2) % ne2; + int i3 = get_group_id(2) / ne2; - if (nidx >= d_ne0) { + if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { return; } - int dst_el_offset = nidx + idx_d1 * d_ne0 + idx_d2 * d_ne0 * d_ne1; + uint src0_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00; + uint dst_idx = i3*nb3 + i2*nb2 + i1*nb1 + i0*nb0; - bool in_src_bounds = (nidx < s_ne0) && (idx_d1 < s_ne1) && (idx_d2 < s_ne2); + global float * src0_ptr = (global float *)((global char *)src0 + src0_idx); + global float * dst_ptr = (global float *)((global char *)dst + dst_idx); - if (in_src_bounds) { - int src_el_offset = nidx + idx_d1 * s_ne0 + idx_d2 * s_ne0 * s_ne1; - dst[dst_el_offset] = src0[src_el_offset]; - } else { - dst[dst_el_offset] = 0.0f; - } + bool in_src_bounds = (i0 >= lp0 && i0 < ne0 - rp0) && + (i1 >= lp1 && i1 < ne1 - rp1) && + (i2 >= lp2 && i2 < ne2 - rp2) && + (i3 >= lp3 && i3 < ne3 - rp3); + + *dst_ptr = in_src_bounds ? *src0_ptr : 0.0f; } diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 2608cbd068..003a901067 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -9,8 +9,14 @@ #define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 // We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE // to avoid conflicts with applications or other libraries who might use it. +#if VK_HEADER_VERSION >= 301 namespace vk::detail { class DispatchLoaderDynamic; } -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); #define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() #include @@ -4538,9 +4544,8 @@ static bool ggml_vk_instance_portability_enumeration_ext_available(const std::ve static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); -static vk::detail::DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; - -vk::detail::DispatchLoaderDynamic & ggml_vk_default_dispatcher() { +static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { return ggml_vk_default_dispatcher_instance; } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp index e80eff2781..9b1f153bf7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.comp @@ -67,30 +67,48 @@ layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; #if defined(A_TYPE_PACKED16) #define BINDING_IDX_K 0 #define BINDING_IDX_V 1 -layout (binding = 1) readonly buffer KV_PACKED16 {A_TYPE_PACKED16 data_packed16[];} kv_packed[2]; +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; #endif #if defined(DATA_A_Q4_0) #define BLOCK_BYTE_SIZE 18 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - uint vui_lo = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); - uint vui_hi = uint(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); - uint shift = (iqs & 0x10) >> 2; - vui_lo >>= shift; - vui_hi >>= shift; + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } } #endif #if defined(DATA_A_Q8_0) #define BLOCK_BYTE_SIZE 34 vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { - const i8vec2 v0 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 - const i8vec2 v1 = unpack8(int32_t(kv_packed[binding_idx].data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; - return float(kv_packed[binding_idx].data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } } #endif diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index cee4b08366..93200a4d29 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -130,13 +130,15 @@ struct webgpu_context_struct { wgpu::ComputePipeline set_rows_pipeline; wgpu::ComputePipeline get_rows_pipeline[30]; wgpu::ComputePipeline get_rows_f32_no_vec_pipeline; - wgpu::ComputePipeline cpy_pipeline; - wgpu::ComputePipeline add_pipeline[2]; - wgpu::ComputePipeline add_ip_pipeline[2]; - wgpu::ComputePipeline mul_pipeline[2]; - wgpu::ComputePipeline mul_ip_pipeline[2]; - wgpu::ComputePipeline rms_norm_pipeline; - wgpu::ComputePipeline rms_norm_ip_pipeline; + wgpu::ComputePipeline cpy_pipeline[2][2]; // src type, dst type + wgpu::ComputePipeline add_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline sub_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline mul_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline div_pipeline[2][2]; // type, inplace + wgpu::ComputePipeline rms_norm_pipeline[2]; // inplace + wgpu::ComputePipeline rope_pipeline[2][2][2]; // type, ff, inplace + wgpu::ComputePipeline glu_pipeline[7][2][2]; // glu-op, type, split + wgpu::ComputePipeline scale_pipeline[2]; // inplace size_t memset_bytes_per_thread; @@ -489,8 +491,9 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor (uint32_t) (src->nb[2] / ggml_type_size(src->type)), (uint32_t) (src->nb[3] / ggml_type_size(src->type)), (uint32_t) (dst->nb[0] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), - // Logical shape — same for both tensors even if permuted - (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) src->ne[3] + // Logical shapes + (uint32_t) src->ne[0], (uint32_t) src->ne[1], (uint32_t) src->ne[2], (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], (uint32_t) dst->ne[2] }; std::vector entries = { @@ -506,7 +509,8 @@ static void ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (ne + max_wg_size - 1) / max_wg_size; - ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline, params, entries, wg_x, ggml_op_name(dst->op)); + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->cpy_pipeline[src->type][dst->type], params, entries, wg_x, + ggml_op_name(dst->op)); } static void ggml_webgpu_set_rows(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * idx, ggml_tensor * dst) { @@ -649,7 +653,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, ggml_tensor * src1, ggml_tensor * dst, wgpu::ComputePipeline & pipeline, - bool in_place) { + bool inplace) { std::vector params = { (uint32_t) ggml_nelements(dst), (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), @@ -678,7 +682,7 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, .offset = ggml_webgpu_tensor_align_offset(ctx, src1), .size = ggml_webgpu_tensor_binding_size(ctx, src1) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 2, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), @@ -691,30 +695,23 @@ static void ggml_webgpu_binary_op(webgpu_context & ctx, } static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { - bool in_place = ggml_webgpu_tensor_equal(src, dst); - - uint32_t eps; - memcpy(&eps, dst->op_params, sizeof(float)); + int inplace = ggml_webgpu_tensor_equal(src, dst); std::vector params = { (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + (uint32_t) src->ne[3], + *(uint32_t *) dst->op_params // epsilon, treated as f32 in the shader }; - if (!in_place) { - params.push_back((uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) (src->nb[1] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[2] / ggml_type_size(src->type))); - params.push_back((uint32_t) (src->nb[3] / ggml_type_size(src->type))); - if (!in_place) { - params.push_back((uint32_t) (dst->nb[1] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[2] / ggml_type_size(dst->type))); - params.push_back((uint32_t) (dst->nb[3] / ggml_type_size(dst->type))); - } - params.push_back((uint32_t) src->ne[0]); - params.push_back((uint32_t) src->ne[1]); - params.push_back((uint32_t) src->ne[2]); - params.push_back((uint32_t) src->ne[3]); - params.push_back(eps); // epsilon, will be bitcast to float in shader std::vector entries = { { .binding = 0, @@ -722,24 +719,199 @@ static void ggml_webgpu_rms_norm(webgpu_context & ctx, ggml_tensor * src, ggml_t .offset = ggml_webgpu_tensor_align_offset(ctx, src), .size = ggml_webgpu_tensor_binding_size(ctx, src) } }; - if (!in_place) { + if (!inplace) { entries.push_back({ .binding = 1, .buffer = ggml_webgpu_tensor_buf(dst), .offset = ggml_webgpu_tensor_align_offset(ctx, dst), .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); } - wgpu::ComputePipeline pipeline; - if (in_place) { - pipeline = ctx->rms_norm_ip_pipeline; - } else { - pipeline = ctx->rms_norm_pipeline; - } size_t max_wg_size = ctx->max_wg_size_x; uint32_t wg_x = (src->ne[1] * src->ne[2] * src->ne[3] + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->rms_norm_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + +static void ggml_webgpu_rope(webgpu_context & ctx, + ggml_tensor * src0, + ggml_tensor * src1, + ggml_tensor * src2, + ggml_tensor * dst) { + const int inplace = ggml_webgpu_tensor_equal(src0, dst); + const int has_freq_factor = (src2 != nullptr); + + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + + float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; + memcpy(&freq_base, (int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (int32_t *) dst->op_params + 6, sizeof(float)); + memcpy(&ext_factor, (int32_t *) dst->op_params + 7, sizeof(float)); + memcpy(&attn_factor, (int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (int32_t *) dst->op_params + 10, sizeof(float)); + + int sections[4]; + memcpy(sections, (int32_t *) dst->op_params + 11, 4 * sizeof(int)); + + float theta_scale = powf(freq_base, -2.0f / n_dims); + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)), + src2 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src2) / ggml_type_size(src2->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(src0) / 2, + (uint32_t) src0->ne[0], + (uint32_t) src0->ne[1], + (uint32_t) src0->ne[2], + (uint32_t) n_dims, + (uint32_t) mode, + *(uint32_t *) &theta_scale, + *(uint32_t *) &attn_factor, + *(uint32_t *) &freq_scale, + *(uint32_t *) &ext_factor, + *(uint32_t *) &corr_dims[0], + *(uint32_t *) &corr_dims[1], + (uint32_t) sections[0], + (uint32_t) sections[1], + (uint32_t) sections[2], + (uint32_t) sections[3] + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + { .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) } + }; + uint32_t dst_binding = 2; + if (has_freq_factor) { + dst_binding = 3; + entries.push_back({ .binding = 2, + .buffer = ggml_webgpu_tensor_buf(src2), + .offset = ggml_webgpu_tensor_align_offset(ctx, src2), + .size = ggml_webgpu_tensor_binding_size(ctx, src2) }); + } + if (!inplace) { + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + wgpu::ComputePipeline pipeline = ctx->rope_pipeline[dst->type][has_freq_factor][inplace]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(src0) / 2 + max_wg_size - 1) / max_wg_size; ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); } +static void ggml_webgpu_glu(webgpu_context & ctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst) { + const int split = (src1 != nullptr); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)) : 0, + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[1] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[1] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[2] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[2] / ggml_type_size(src0->type)), + src1 != nullptr ? (uint32_t) (src1->nb[3] / ggml_type_size(src1->type)) : + (uint32_t) (src0->nb[3] / ggml_type_size(src0->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) dst->ne[0], + (uint32_t) dst->ne[1], + (uint32_t) dst->ne[2], + (uint32_t) ((int32_t *) dst->op_params)[1], // swapped + *(uint32_t *) &dst->op_params[2], // alpha, for swiglu_oai + *(uint32_t *) &dst->op_params[3], // limit, for swiglu_oai + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src0), + .offset = ggml_webgpu_tensor_align_offset(ctx, src0), + .size = ggml_webgpu_tensor_binding_size(ctx, src0) }, + }; + uint32_t dst_binding = 1; + if (split) { + dst_binding = 2; + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(src1), + .offset = ggml_webgpu_tensor_align_offset(ctx, src1), + .size = ggml_webgpu_tensor_binding_size(ctx, src1) }); + } + entries.push_back({ .binding = dst_binding, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + + wgpu::ComputePipeline pipeline = ctx->glu_pipeline[ggml_get_glu_op(dst)][dst->type][split]; + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, pipeline, params, entries, wg_x, ggml_op_name(dst->op)); +} + +static void ggml_webgpu_scale(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) { + int inplace = ggml_webgpu_tensor_equal(src, dst); + + std::vector params = { + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src) / ggml_type_size(src->type)), + (uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)), + (uint32_t) (src->nb[1] / ggml_type_size(src->type)), + (uint32_t) (src->nb[2] / ggml_type_size(src->type)), + (uint32_t) (src->nb[3] / ggml_type_size(src->type)), + (uint32_t) (dst->nb[1] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[2] / ggml_type_size(dst->type)), + (uint32_t) (dst->nb[3] / ggml_type_size(dst->type)), + (uint32_t) ggml_nelements(dst), + (uint32_t) src->ne[0], + (uint32_t) src->ne[1], + (uint32_t) src->ne[2], + *(uint32_t *) dst->op_params, // scale + *(uint32_t *) &dst->op_params[1] // bias + }; + + std::vector entries = { + { .binding = 0, + .buffer = ggml_webgpu_tensor_buf(src), + .offset = ggml_webgpu_tensor_align_offset(ctx, src), + .size = ggml_webgpu_tensor_binding_size(ctx, src) } + }; + if (!inplace) { + entries.push_back({ .binding = 1, + .buffer = ggml_webgpu_tensor_buf(dst), + .offset = ggml_webgpu_tensor_align_offset(ctx, dst), + .size = ggml_webgpu_tensor_binding_size(ctx, dst) }); + } + + size_t max_wg_size = ctx->max_wg_size_x; + uint32_t wg_x = (ggml_nelements(dst) + max_wg_size - 1) / max_wg_size; + ggml_backend_webgpu_build_and_enqueue(ctx, ctx->scale_pipeline[inplace], params, entries, wg_x, + ggml_op_name(dst->op)); +} + // Returns true if node has enqueued work into the queue, false otherwise static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { if (ggml_is_empty(node)) { @@ -749,6 +921,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_tensor * src0 = node->src[0]; ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; switch (node->op) { // no-ops @@ -759,6 +932,7 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { case GGML_OP_RESHAPE: return false; case GGML_OP_CPY: + case GGML_OP_CONT: ggml_webgpu_cpy(ctx, src0, node); break; case GGML_OP_SET_ROWS: @@ -771,22 +945,41 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) { ggml_webgpu_mul_mat(ctx, src0, src1, node); break; case GGML_OP_ADD: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_SUB: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_MUL: - if (ggml_webgpu_tensor_equal(src0, node)) { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_ip_pipeline[node->type], true); - } else { - ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type], false); + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipeline[node->type][inplace], inplace); + break; + } + case GGML_OP_DIV: + { + int inplace = ggml_webgpu_tensor_equal(src0, node); + ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipeline[node->type][inplace], inplace); + break; } - break; case GGML_OP_RMS_NORM: ggml_webgpu_rms_norm(ctx, src0, node); break; + case GGML_OP_ROPE: + ggml_webgpu_rope(ctx, src0, src1, src2, node); + break; + case GGML_OP_GLU: + ggml_webgpu_glu(ctx, src0, src1, node); + break; + case GGML_OP_SCALE: + ggml_webgpu_scale(ctx, src0, node); + break; default: return false; } @@ -1170,40 +1363,153 @@ static void ggml_webgpu_init_get_rows_pipeline(webgpu_context & webgpu_ctx) { } static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) { - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline, wgsl_cpy, "cpy", - ggml_webgpu_max_wg_size_entry(webgpu_ctx)); + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F32], + wgsl_cpy_f32_f32, "cpy_f32_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F32][GGML_TYPE_F16], + wgsl_cpy_f32_f16, "cpy_f32_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F32], + wgsl_cpy_f16_f32, "cpy_f16_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->cpy_pipeline[GGML_TYPE_F16][GGML_TYPE_F16], + wgsl_cpy_f16_f16, "cpy_f16_f16", constants); } static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32], wgsl_add_f32, "add_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][0], wgsl_add_f32, "add_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16], wgsl_add_f16, "add_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][0], wgsl_add_f16, "add_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F32], wgsl_add_in_place_f32, - "add_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_ip_pipeline[GGML_TYPE_F16], wgsl_add_in_place_f16, - "add_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F32][1], wgsl_add_f32_inplace, + "add_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->add_pipeline[GGML_TYPE_F16][1], wgsl_add_f16_inplace, + "add_f16_inplace", constants); +} + +static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][0], wgsl_sub_f32, "sub_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][0], wgsl_sub_f16, "sub_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F32][1], wgsl_sub_f32_inplace, + "sub_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->sub_pipeline[GGML_TYPE_F16][1], wgsl_sub_f16_inplace, + "sub_f16_inplace", constants); } static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32], wgsl_mul_f32, "mul_f32", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][0], wgsl_mul_f32, "mul_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16], wgsl_mul_f16, "mul_f16", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][0], wgsl_mul_f16, "mul_f16", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F32], wgsl_mul_in_place_f32, - "mul_in_place_f32", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_ip_pipeline[GGML_TYPE_F16], wgsl_mul_in_place_f16, - "mul_in_place_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F32][1], wgsl_mul_f32_inplace, + "mul_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->mul_pipeline[GGML_TYPE_F16][1], wgsl_mul_f16_inplace, + "mul_f16_inplace", constants); +} + +static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][0], wgsl_div_f32, "div_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][0], wgsl_div_f16, "div_f16", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F32][1], wgsl_div_f32_inplace, + "div_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->div_pipeline[GGML_TYPE_F16][1], wgsl_div_f16_inplace, + "div_f16_inplace", constants); } static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) { std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline, wgsl_rms_norm, "rms_norm", + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[0], wgsl_rms_norm, "rms_norm", constants); - ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_ip_pipeline, wgsl_rms_norm_in_place, - "rms_norm_in_place", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rms_norm_pipeline[1], wgsl_rms_norm_inplace, + "rms_norm_inplace", constants); +} + +static void ggml_webgpu_init_rope_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][0], wgsl_rope_f32, + "rope_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][0][1], + wgsl_rope_f32_inplace, "rope_f32_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][0], wgsl_rope_f32_ff, + "rope_f32_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F32][1][1], + wgsl_rope_f32_ff_inplace, "rope_f32_ff_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][0], wgsl_rope_f16, + "rope_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][0][1], + wgsl_rope_f16_inplace, "rope_f16_inplace", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][0], wgsl_rope_f16_ff, + "rope_f16_ff", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->rope_pipeline[GGML_TYPE_F16][1][1], + wgsl_rope_f16_ff_inplace, "rope_f16_ff_inplace", constants); +} + +static void ggml_webgpu_init_glu_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + // reglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][0], + wgsl_reglu_f32, "reglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][0], + wgsl_reglu_f16, "reglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F32][1], + wgsl_reglu_f32_split, "reglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_REGLU][GGML_TYPE_F16][1], + wgsl_reglu_f16_split, "reglu_f16_split", constants); + // geglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][0], + wgsl_geglu_f32, "geglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][0], + wgsl_geglu_f16, "geglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F32][1], + wgsl_geglu_f32_split, "geglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU][GGML_TYPE_F16][1], + wgsl_geglu_f16_split, "geglu_f16_split", constants); + // swiglu + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][0], + wgsl_swiglu_f32, "swiglu_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][0], + wgsl_swiglu_f16, "swiglu_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F32][1], + wgsl_swiglu_f32_split, "swiglu_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU][GGML_TYPE_F16][1], + wgsl_swiglu_f16_split, "swiglu_f16_split", constants); + // swiglu_oai + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][0], + wgsl_swiglu_oai_f32, "swiglu_oai_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_SWIGLU_OAI][GGML_TYPE_F32][1], + wgsl_swiglu_oai_f32_split, "swiglu_oai_f32_split", constants); + // geglu_erf + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][0], + wgsl_geglu_erf_f32, "geglu_erf_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][0], + wgsl_geglu_erf_f16, "geglu_erf_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F32][1], + wgsl_geglu_erf_f32_split, "geglu_erf_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_ERF][GGML_TYPE_F16][1], + wgsl_geglu_erf_f16_split, "geglu_erf_f16_split", constants); + // geglu_quick + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][0], + wgsl_geglu_quick_f32, "geglu_quick_f32", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][0], + wgsl_geglu_quick_f16, "geglu_quick_f16", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F32][1], + wgsl_geglu_quick_f32_split, "geglu_quick_f32_split", constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->glu_pipeline[GGML_GLU_OP_GEGLU_QUICK][GGML_TYPE_F16][1], + wgsl_geglu_quick_f16_split, "geglu_quick_f16_split", constants); +} + +static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { + std::vector constants = ggml_webgpu_max_wg_size_entry(webgpu_ctx); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[0], wgsl_scale_f32, "scale_f32", + constants); + ggml_webgpu_create_pipeline(webgpu_ctx->device, webgpu_ctx->scale_pipeline[1], wgsl_scale_f32_inplace, + "scale_f32_inplace", constants); } static ggml_backend_t ggml_backend_webgpu_device_init(ggml_backend_dev_t dev, const char * params) { @@ -1287,6 +1593,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * src0 = op->src[0]; ggml_tensor * src1 = op->src[1]; + // on smaller devices (or CI), tensors may be larger than the max storage buffer size if (ggml_nbytes(op) > webgpu_ctx->limits.maxStorageBufferBindingSize || (src0 != nullptr && ggml_nbytes(src0) > webgpu_ctx->limits.maxStorageBufferBindingSize) || @@ -1304,28 +1611,34 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const supports_op = true; break; case GGML_OP_ADD: + case GGML_OP_SUB: case GGML_OP_MUL: - supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (op->src[0]->type == op->type) && - (op->src[1]->type == op->type); + case GGML_OP_DIV: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type) && + (src1->type == op->type); break; case GGML_OP_CPY: + case GGML_OP_CONT: + supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + break; case GGML_OP_SET_ROWS: supports_op = (op->type == GGML_TYPE_F16 && op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_I64); break; case GGML_OP_GET_ROWS: - if (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || - op->src[0]->type == GGML_TYPE_I32 || ggml_webgpu_supported_qtype(op->src[0]->type)) { + if (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_I32 || + ggml_webgpu_supported_qtype(src0->type)) { supports_op = (op->type == GGML_TYPE_F32); } break; case GGML_OP_MUL_MAT: { - switch (op->src[1]->type) { + switch (src1->type) { case GGML_TYPE_F16: - supports_op = (op->src[0]->type == GGML_TYPE_F16); + supports_op |= (src0->type == GGML_TYPE_F16); break; case GGML_TYPE_F32: - switch (op->src[0]->type) { + switch (src0->type) { case GGML_TYPE_F32: case GGML_TYPE_F16: case GGML_TYPE_Q4_0: @@ -1358,7 +1671,29 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const break; } case GGML_OP_RMS_NORM: - supports_op = op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + supports_op = op->type == GGML_TYPE_F32 && src0->type == GGML_TYPE_F32; + break; + case GGML_OP_ROPE: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + supports_op = op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16; + break; + case GGML_GLU_OP_SWIGLU_OAI: + supports_op = op->type == GGML_TYPE_F32; + break; + default: + break; + } + break; + case GGML_OP_SCALE: + supports_op = op->type == GGML_TYPE_F32; break; default: break; @@ -1484,8 +1819,13 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t ggml_webgpu_init_get_rows_pipeline(ctx); ggml_webgpu_init_cpy_pipeline(ctx); ggml_webgpu_init_add_pipeline(ctx); + ggml_webgpu_init_sub_pipeline(ctx); ggml_webgpu_init_mul_pipeline(ctx); + ggml_webgpu_init_div_pipeline(ctx); ggml_webgpu_init_rms_norm_pipeline(ctx); + ggml_webgpu_init_rope_pipeline(ctx); + ggml_webgpu_init_glu_pipeline(ctx); + ggml_webgpu_init_scale_pipeline(ctx); #ifdef GGML_WEBGPU_DEBUG // Initialize debug buffers diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl deleted file mode 100644 index f261cbb553..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl deleted file mode 100644 index 903f7bdbcc..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/add_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl new file mode 100644 index 0000000000..1ce4d83fa8 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/bin_op.tmpl.wgsl @@ -0,0 +1,188 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "add_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "add_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "add_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "+" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "mul_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "mul_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "*" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "sub_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "sub_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "-" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f32", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f16", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "div_f32_inplace", + "REPLS": { + "TYPE" : "f32", + "OP": "/" + }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "div_f16_inplace", + "REPLS": { + "TYPE" : "f16", + "OP": "/" + }, + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(dst_i: u32, src0_i: u32, src1_i: u32) { + src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i]; +} + +@group(0) @binding(2) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + + +#define(SHADER) + +enable f16; + +#include "binary_head.tmpl" + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x < params.ne) { + update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x)); + } +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl new file mode 100644 index 0000000000..db1aa34903 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.tmpl.wgsl @@ -0,0 +1,101 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f32" + } + }, + { + "REPLS": { + "SRC_TYPE": "f32", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f16" + } + }, + { + "REPLS": { + "SRC_TYPE": "f16", + "DST_TYPE": "f32" + } + } +] + +#end(VARIANTS) + +#define(SHADER) +enable f16; + +@group(0) @binding(0) +var src: array<{{SRC_TYPE}}>; + +@group(0) @binding(1) +var dst: array<{{DST_TYPE}}>; + +struct Params { + ne: u32, // total number of elements + offset_src: u32, // in elements + offset_dst: u32, // in elements + + // Strides (in elements) — may be permuted + stride_src0: u32, + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst0: u32, + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // Logical shapes + src_ne0: u32, + src_ne1: u32, + src_ne2: u32, + + dst_ne0: u32, + dst_ne1: u32, + dst_ne2: u32 +}; + +@group(0) @binding(2) +var params: Params; + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.src_ne2 * params.src_ne1 * params.src_ne0); + i = i % (params.src_ne2 * params.src_ne1 * params.src_ne0); + let i2 = i / (params.src_ne1 * params.src_ne0); + i = i % (params.src_ne1 * params.src_ne0); + let i1 = i / params.src_ne0; + let i0 = i % params.src_ne0; + + var j = gid.x; + let j3 = j / (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne2 * params.dst_ne1 * params.dst_ne0); + let j2 = j / (params.dst_ne1 * params.dst_ne0); + j = j % (params.dst_ne1 * params.dst_ne0); + let j1 = j / params.dst_ne0; + let j0 = j % params.dst_ne0; + + let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + + i2 * params.stride_src2 + i3 * params.stride_src3; + + let dst_idx = j0 * params.stride_dst0 + j1 * params.stride_dst1 + + j2 * params.stride_dst2 + j3 * params.stride_dst3; + + dst[params.offset_dst + dst_idx] = {{DST_TYPE}}((src[params.offset_src + src_idx])); +} +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl deleted file mode 100644 index 6fe924c554..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/cpy.wgsl +++ /dev/null @@ -1,60 +0,0 @@ -enable f16; - -@group(0) @binding(0) -var src: array; - -@group(0) @binding(1) -var dst: array; - -struct Params { - ne: u32, // total number of elements - offset_src: u32, // in elements - offset_dst: u32, // in elements - - // Strides (in elements) — may be permuted - stride_src0: u32, - stride_src1: u32, - stride_src2: u32, - stride_src3: u32, - - stride_dst0: u32, - stride_dst1: u32, - stride_dst2: u32, - stride_dst3: u32, - - // Logical shape (same for both tensors) - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, -}; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne) { - return; - } - - var i = gid.x; - - let i3 = i / (params.ne2 * params.ne1 * params.ne0); - i = i % (params.ne2 * params.ne1 * params.ne0); - - let i2 = i / (params.ne1 * params.ne0); - i = i % (params.ne1 * params.ne0); - - let i1 = i / params.ne0; - let i0 = i % params.ne0; - - let src_idx = i0 * params.stride_src0 + i1 * params.stride_src1 + - i2 * params.stride_src2 + i3 * params.stride_src3; - - let dst_idx = i0 * params.stride_dst0 + i1 * params.stride_dst1 + - i2 * params.stride_dst2 + i3 * params.stride_dst3; - - dst[params.offset_dst + dst_idx] = f16(src[params.offset_src + src_idx]); -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py index d9dfd7d6f4..251051eaec 100755 --- a/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py +++ b/ggml/src/ggml-webgpu/wgsl-shaders/embed_wgsl.py @@ -88,15 +88,20 @@ def generate_variants(fname, input_dir, output_dir, outfile): raise ValueError(f"DECLS key '{key}' not found.") decls_code += decls_map[key] + "\n\n" - shader_variant = replace_placeholders(shader_template, variant["REPLS"]) - final_shader = re.sub(r'\bDECLS\b', decls_code, shader_variant) + final_shader = re.sub(r'\bDECLS\b', decls_code, shader_template) + if "REPLS" in variant: + final_shader = replace_placeholders(final_shader, variant["REPLS"]) final_shader = expand_includes(final_shader, input_dir) - if "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: + if "SHADER_NAME" in variant: + output_name = variant["SHADER_NAME"] + elif "SHADER_SUFFIX" in variant: + output_name = f"{shader_base_name}_" + variant["SHADER_SUFFIX"] + elif "REPLS" in variant and "SRC0_TYPE" in variant["REPLS"] and "SRC1_TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC0_TYPE"], variant["REPLS"]["SRC1_TYPE"]]) - elif "TYPE_SUFFIX" in variant["REPLS"]: - output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE_SUFFIX"] - elif "TYPE" in variant["REPLS"]: + elif "REPLS" in variant and "SRC_TYPE" in variant["REPLS"] and "DST_TYPE" in variant["REPLS"]: + output_name = f"{shader_base_name}_" + "_".join([variant["REPLS"]["SRC_TYPE"], variant["REPLS"]["DST_TYPE"]]) + elif "REPLS" in variant and "TYPE" in variant["REPLS"]: output_name = f"{shader_base_name}_" + variant["REPLS"]["TYPE"] else: output_name = shader_base_name diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl index e3fe311b26..f80ce1fc55 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/get_rows.tmpl.wgsl @@ -2,9 +2,9 @@ [ { + "SHADER_SUFFIX": "f32_vec", "REPLS": { "TYPE" : "vec4", - "TYPE_SUFFIX": "f32_vec", "DST_TYPE": "vec4", "BLOCK_SIZE": 4 }, diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl new file mode 100644 index 0000000000..03fcd54868 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/glu.tmpl.wgsl @@ -0,0 +1,323 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "reglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "reglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "REGLU"] + }, + { + "SHADER_NAME": "geglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "geglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU"] + }, + { + "SHADER_NAME": "swiglu_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "SWIGLU"] + }, + { + "SHADER_NAME": "swiglu_oai_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "swiglu_oai_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "SWIGLU_OAI"] + }, + { + "SHADER_NAME": "geglu_erf_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_erf_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_ERF"] + }, + { + "SHADER_NAME": "geglu_quick_f32", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f32_split", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_SPLIT", "GEGLU_QUICK"] + }, + { + "SHADER_NAME": "geglu_quick_f16_split", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["SPLIT", "GEGLU_QUICK"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(REGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return max(a, 0) * b; +} +#enddecl(REGLU) + +#decl(GEGLU) +const SQRT_2_OVER_PI: {{TYPE}} = 0.79788456080286535587989211986876; +const GELU_COEF_A: {{TYPE}} = 0.044715; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let val = SQRT_2_OVER_PI * a * (1.0 + GELU_COEF_A * a * a); + return 0.5 * a * (2.0 - 2.0 / (exp(2 * val) + 1)) * b; +} +#enddecl(GEGLU) + +#decl(SWIGLU) +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a / (1.0 + exp(-a)) * b; +} +#enddecl(SWIGLU) + +#decl(SWIGLU_OAI) +fn op(a: f32, b: f32) -> f32 { + let xi = min(a, params.limit); + let gi = max(min(b, params.limit), -params.limit); + var out_glu = xi / (1.0 + exp(-xi * params.alpha)); + out_glu = out_glu * (1.0 + gi); + return out_glu; +} +#enddecl(SWIGLU_OAI) + +#decl(GEGLU_ERF) +const p_erf: {{TYPE}} = 0.3275911; +const a1_erf: {{TYPE}} = 0.254829592; +const a2_erf: {{TYPE}} = -0.284496736; +const a3_erf: {{TYPE}} = 1.421413741; +const a4_erf: {{TYPE}} = -1.453152027; +const a5_erf: {{TYPE}} = 1.061405429; +const SQRT_2_INV: {{TYPE}} = 0.7071067811865476; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + let a_div_sqr2 = a * SQRT_2_INV; + let sign_x = sign(a_div_sqr2); + let x = abs(a_div_sqr2); + let t = 1.0 / (1.0 + p_erf * x); + let y = 1.0 - (((((a5_erf * t + a4_erf) * t + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x)); + let erf_approx = sign_x * y; + return 0.5 * a * (1.0 + erf_approx) * b; +} +#enddecl(GEGLU_ERF) + +#decl(GEGLU_QUICK) +const GELU_QUICK_COEF: {{TYPE}} = -1.702; + +fn op(a: {{TYPE}}, b: {{TYPE}}) -> {{TYPE}} { + return a * (1.0 / (1.0 + exp(GELU_QUICK_COEF * a))) * b; +} +#enddecl(GEGLU_QUICK) + +#decl(NO_SPLIT) +@group(0) @binding(1) +var dst: array<{{TYPE}}>; + +@group(0) @binding(2) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(0, params.ne0, params.swapped != 0); + return src0[base + offset]; +} + +fn b_value(base: u32) -> {{TYPE}} { + let offset: u32 = select(params.ne0, 0, params.swapped != 0); + return src0[base + offset]; +} +#enddecl(NO_SPLIT) + +#decl(SPLIT) +@group(0) @binding(1) +var src1: array<{{TYPE}}>; + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +fn a_value(base: u32) -> {{TYPE}} { + return src0[base]; +} + +fn b_value(base: u32) -> {{TYPE}} { + return src1[base]; +} +#enddecl(SPLIT) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_src11: u32, + stride_src12: u32, + stride_src13: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + // shape of dst + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + swapped: u32, + alpha: f32, + limit: f32, +} + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_a = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01 + i0; + let i_b = params.offset_src1 + i3 * params.stride_src13 + i2 * params.stride_src12 + i1 * params.stride_src11 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + dst[i_dst] = op(a_value(i_a), b_value(i_b)); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl deleted file mode 100644 index 12506e1420..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul.tmpl.wgsl +++ /dev/null @@ -1,44 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var dst: array<{{TYPE}}>; - -@group(0) @binding(3) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl deleted file mode 100644 index e467e59edb..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/mul_in_place.tmpl.wgsl +++ /dev/null @@ -1,41 +0,0 @@ -#define(VARIANTS) - -[ - { - "REPLS": { - "TYPE" : "f32", - } - }, - { - "REPLS": { - "TYPE" : "f16", - } - } -] - -#end(VARIANTS) - -#define(SHADER) - -enable f16; - -#include "binary_head.tmpl" - -@group(0) @binding(0) -var src0: array<{{TYPE}}>; - -@group(0) @binding(1) -var src1: array<{{TYPE}}>; - -@group(0) @binding(2) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x < params.ne) { - src0[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] * src1[params.offset_src1 + src1_index(gid.x)]; - } -} - -#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl index f919a51336..a275eeb978 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm.wgsl @@ -1,9 +1,48 @@ -@group(0) @binding(0) -var src: array; +#define(VARIANTS) + +[ + { + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_SUFFIX": "inplace", + "DECLS": ["INPLACE"] + }, +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + dst[dst_offset] = scale * src[src_offset]; +} @group(0) @binding(1) var dst: array; +@group(0) @binding(2) +var params: Params; + +#enddecl(NOT_INPLACE) + +#decl(INPLACE) + +fn update(src_offset: u32, dst_offset: u32, scale: f32) { + src[dst_offset] = scale * src[src_offset]; +} + +@group(0) @binding(1) +var params: Params; + +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + struct Params { offset_src: u32, // in elements offset_dst: u32, // in elements @@ -23,11 +62,13 @@ struct Params { ne2: u32, ne3: u32, - eps: u32 + eps: f32 }; -@group(0) @binding(2) -var params: Params; +@group(0) @binding(0) +var src: array; + +DECLS override wg_size: u32; @compute @workgroup_size(wg_size) @@ -49,9 +90,9 @@ fn main(@builtin(global_invocation_id) gid: vec3) { for (var j: u32 = 0; j < params.ne0; j++) { sum += src[i_src_row + j] * src[i_src_row + j]; } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); + let scale = 1.0/sqrt(sum/f32(params.ne0) + params.eps); for (var j: u32 = 0; j < params.ne0; j++) { - dst[i_dst_row + j] = scale * src[i_src_row + j]; + update(i_src_row + j, i_dst_row + j, scale); } } +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl deleted file mode 100644 index ae84f556d6..0000000000 --- a/ggml/src/ggml-webgpu/wgsl-shaders/rms_norm_in_place.wgsl +++ /dev/null @@ -1,48 +0,0 @@ -@group(0) @binding(0) -var a: array; - -struct Params { - offset: u32, // in elements - - // Strides (in elements) - stride1: u32, - stride2: u32, - stride3: u32, - - // Shape - ne0: u32, - ne1: u32, - ne2: u32, - ne3: u32, - - eps: u32 -}; - -@group(0) @binding(1) -var params: Params; - -override wg_size: u32; -@compute @workgroup_size(wg_size) -fn main(@builtin(global_invocation_id) gid: vec3) { - if (gid.x >= params.ne1 * params.ne2 * params.ne3) { - return; - } - - // one thread per row - var i = gid.x; - let i3 = i / (params.ne2 * params.ne1); - i = i % (params.ne2 * params.ne1); - let i2 = i / params.ne1; - let i1 = i % params.ne1; - let i_row = params.offset + i3 * params.stride3 + i2 * params.stride2 + i1 * params.stride1; - - var sum = 0.0f; - for (var j: u32 = 0; j < params.ne0; j++) { - sum += a[i_row + j] * a[i_row + j]; - } - let eps = bitcast(params.eps); - let scale = 1.0/sqrt(sum/f32(params.ne0) + eps); - for (var j: u32 = 0; j < params.ne0; j++) { - a[i_row + j] = scale * a[i_row + j]; - } -} diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl new file mode 100644 index 0000000000..9a6ff41128 --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/rope.tmpl.wgsl @@ -0,0 +1,282 @@ +#define(VARIANTS) + +[ + { + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS", "NO_FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["NO_FF_BINDINGS_INPLACE", "NO_FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f32_ff", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f32_ff_inplace", + "REPLS": { + "TYPE" : "f32", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + }, + { + "SHADER_SUFFIX": "f16_ff", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS", "FF_FUNC", "ROTATE"] + }, + { + "SHADER_SUFFIX": "f16_ff_inplace", + "REPLS": { + "TYPE" : "f16", + }, + "DECLS": ["FF_BINDINGS_INPLACE", "FF_FUNC", "ROTATE_INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(ROTATE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + dst[i_dst0] = {{TYPE}}(out0); + dst[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE) + +#decl(ROTATE_INPLACE) +fn rotate(i_dst0: u32, i_dst1: u32, out0: f32, out1: f32) { + src0[i_dst0] = {{TYPE}}(out0); + src0[i_dst1] = {{TYPE}}(out1); +} +#enddecl(ROTATE_INPLACE) + +#decl(NO_FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return 1.0f; +} +#enddecl(NO_FF_FUNC) + +#decl(FF_FUNC) +fn freq_factor(i: u32) -> f32 { + return src2[params.offset_src2 + i/2]; +} +#enddecl(FF_FUNC) + +#decl(NO_FF_BINDINGS) + +@group(0) @binding(2) +var dst: array<{{TYPE}}>; + +@group(0) @binding(3) +var params: Params; + +#enddecl(NO_FF_BINDINGS) + +#decl(NO_FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var params: Params; + +#enddecl(NO_FF_BINDINGS_INPLACE) + +#decl(FF_BINDINGS) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var dst: array<{{TYPE}}>; + +@group(0) @binding(4) +var params: Params; + +#enddecl(FF_BINDINGS) + +#decl(FF_BINDINGS_INPLACE) + +@group(0) @binding(2) +var src2: array; + +@group(0) @binding(3) +var params: Params; + +#enddecl(FF_BINDINGS_INPLACE) + +#end(DECLS) + +#define(SHADER) + +enable f16; + +struct Params { + offset_src0: u32, + offset_src1: u32, + offset_src2: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src01: u32, + stride_src02: u32, + stride_src03: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + n_threads: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + n_dims: u32, + mode: u32, + theta_scale: f32, + attn_factor: f32, + freq_scale: f32, + ext_factor: f32, + corr_dim0: f32, + corr_dim1: f32, + sections0: u32, + sections1: u32, + sections2: u32, + sections3: u32 +}; + +@group(0) @binding(0) +var src0: array<{{TYPE}}>; + +@group(0) @binding(1) +var src1: array; + +DECLS + +fn rope_yarn_ramp(low: f32, high: f32, i: u32) -> f32 { + let y = (f32(i / 2) - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +// returns vector of (cos_theta, sin_theta) +// TODO: check performance of instantiating once on the CPU and passed as buffer, since it's repeated per-row +fn rope_yarn(theta_extrap: f32, i: u32) -> vec2 { + var mscale = params.attn_factor; + var theta = params.freq_scale * theta_extrap; + if (params.ext_factor != 0.0f) { + let ramp_mix = rope_yarn_ramp(params.corr_dim0, params.corr_dim1, i) * params.ext_factor; + theta = theta * (1 - ramp_mix) + theta_extrap * ramp_mix; + mscale *= 1.0f + 0.1f * log(1.0f / params.freq_scale); + } + return vec2(cos(theta) * mscale, sin(theta) * mscale); +} + +fn pair_base(i0: u32, div_2: bool) -> u32 { + if (div_2) { + return i0 / 2; + } else { + return i0; + } +} + +fn pair_offset(is_neox: bool, is_mrope: bool, is_vision: bool) -> u32 { + if (is_vision) { + return params.n_dims; + } else if (is_neox || is_mrope) { + return params.n_dims / 2; + } else { + return 1; + } +} + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + // two elements per thread + if (gid.x >= params.n_threads) { + return; + } + + let is_neox = bool(params.mode & 2); + let is_mrope = bool(params.mode & 8); + let is_vision = params.mode == 24; + + var i = gid.x * 2; // start index for this thread + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src_row = params.offset_src0 + i3 * params.stride_src03 + i2 * params.stride_src02 + i1 * params.stride_src01; + let i_dst_row = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1; + + if (i0 >= params.n_dims && !is_vision) { + let i_src = i_src_row + i0; + let i_dst = i_dst_row + i0; + rotate(i_dst, i_dst + 1, f32(src0[i_src]), f32(src0[i_src + 1])); + return; + } + + var theta_base_mult: u32 = 0; + var theta_scale_pwr: u32 = i0 / 2; + if (is_mrope) { + let sect_dims = params.sections0 + params.sections1 + params.sections2 + params.sections3; + let sec_w = params.sections1 + params.sections0; + let sec_e = params.sections2 + sec_w; + let sector = (i0 / 2) % sect_dims; + if (sector >= params.sections0 && sector < sec_w) { + theta_base_mult = 1; + if (is_vision) { + theta_scale_pwr = sector - params.sections0; + } + } else if (sector >= sec_w && sector < sec_e) { + theta_base_mult = 2; + if (is_vision) { + theta_scale_pwr = sector - sec_w; + } + } else if (sector >= sec_e) { + if (is_vision) { + theta_scale_pwr = sector - sec_e; + theta_scale_pwr = (i0 / 2) % sec_e; + } + theta_base_mult = 3; + } else if (is_vision) { + theta_scale_pwr = sector; + } + } + let theta_base = f32(src1[params.offset_src1 + i2 + params.ne2 * theta_base_mult]) * pow(params.theta_scale, f32(theta_scale_pwr)); + let thetas = rope_yarn(theta_base/freq_factor(i0), i0); + + let i_src = i_src_row + pair_base(i0, is_neox || is_mrope || is_vision); + let i_dst = i_dst_row + pair_base(i0, is_neox || is_mrope || is_vision); + + let x0 = f32(src0[i_src]); + let x1 = f32(src0[i_src + pair_offset(is_neox, is_mrope, is_vision)]); + rotate(i_dst, i_dst + pair_offset(is_neox, is_mrope, is_vision), x0 * thetas.x - x1 * thetas.y, x0 * thetas.y + x1 * thetas.x); +} + +#end(SHADER) diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl new file mode 100644 index 0000000000..040e80dfea --- /dev/null +++ b/ggml/src/ggml-webgpu/wgsl-shaders/scale.tmpl.wgsl @@ -0,0 +1,90 @@ +#define(VARIANTS) + +[ + { + "SHADER_NAME": "scale_f32", + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "scale_f32_inplace", + "DECLS": ["INPLACE"] + } +] + +#end(VARIANTS) + +#define(DECLS) + +#decl(NOT_INPLACE) +@group(0) @binding(1) +var dst: array; + +@group(0) @binding(2) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + dst[offset] = val; +} +#enddecl(NOT_INPLACE) + +#decl(INPLACE) +@group(0) @binding(1) +var params: Params; + +fn store_scale(val: f32, offset: u32) { + src[offset] = val; +} +#enddecl(INPLACE) + +#end(DECLS) + +#define(SHADER) + +struct Params { + offset_src: u32, + offset_dst: u32, + + // Strides (in elements) + stride_src1: u32, + stride_src2: u32, + stride_src3: u32, + + stride_dst1: u32, + stride_dst2: u32, + stride_dst3: u32, + + ne: u32, + ne0: u32, + ne1: u32, + ne2: u32, + + scale: f32, + bias: f32 +}; + +@group(0) @binding(0) +var src: array; + +DECLS + +override wg_size: u32; +@compute @workgroup_size(wg_size) +fn main(@builtin(global_invocation_id) gid: vec3) { + if (gid.x >= params.ne) { + return; + } + + var i = gid.x; + let i3 = i / (params.ne2 * params.ne1 * params.ne0); + i = i % (params.ne2 * params.ne1 * params.ne0); + let i2 = i / (params.ne1 * params.ne0); + i = i % (params.ne1 * params.ne0); + let i1 = i / params.ne0; + let i0 = i % params.ne0; + + let i_src = params.offset_src + i3 * params.stride_src3 + i2 * params.stride_src2 + i1 * params.stride_src1 + i0; + let i_dst = params.offset_dst + i3 * params.stride_dst3 + i2 * params.stride_dst2 + i1 * params.stride_dst1 + i0; + + store_scale(src[i_src] * params.scale + params.bias, i_dst); +} +#end(SHADER) diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index a5796214f2..aecbdad5a3 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -3687,6 +3687,7 @@ struct ggml_tensor * ggml_set_rows( result->op = GGML_OP_SET_ROWS; result->src[0] = b; result->src[1] = c; + result->src[2] = a; // note: order is weird due to legacy reasons (https://github.com/ggml-org/llama.cpp/pull/16063#discussion_r2385795931) return result; } diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 733d30cfa2..5e09de499e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -978f6e1993f2eeb4e99b63d4e70b4401c0a2dae2 +72632094336524a9c809e129e8b1c52154543a5a diff --git a/src/llama-model.cpp b/src/llama-model.cpp index a96ab347c9..2779ba1f4a 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -4879,11 +4879,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -11927,6 +11929,7 @@ struct llm_graph_context_mamba : public llm_graph_context { // TODO: skip computing output earlier for unused tokens y = ggml_add(ctx0, y, ggml_mul(ctx0, x, model.layers[il].ssm_d)); + cb(y, "mamba2_y_add_d", il); y = ggml_swiglu_split(ctx0, ggml_cont(ctx0, z), y); // grouped RMS norm @@ -14881,6 +14884,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { ggml_tensor * inpL; inpL = build_inp_embd(model.tok_embd); + ggml_build_forward_expand(gf, inpL); auto * inp = build_inp_mem_hybrid(); @@ -14912,7 +14916,7 @@ struct llm_build_nemotron_h : public llm_graph_context_mamba { // add residual cur = ggml_add(ctx0, cur, inpSA); - cb(cur, "block_out", il); + cb(cur, "nemotron_h_block_out", il); // input for next layer inpL = cur; diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index e2836ca481..a60ca12fe5 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -126,52 +126,35 @@ int main(void) { assert(params.cpuparams.n_threads == 1010); #endif // _WIN32 - if (common_has_curl()) { - printf("test-arg-parser: test curl-related functions\n\n"); - const char * GOOD_URL = "https://ggml.ai/"; - const char * BAD_URL = "https://www.google.com/404"; - const char * BIG_FILE = "https://huggingface.co/ggerganov/whisper.cpp/resolve/main/ggml-large-v1.bin"; + printf("test-arg-parser: test curl-related functions\n\n"); + const char * GOOD_URL = "http://ggml.ai/"; + const char * BAD_URL = "http://ggml.ai/404"; - { - printf("test-arg-parser: test good URL\n\n"); - auto res = common_remote_get_content(GOOD_URL, {}); - assert(res.first == 200); - assert(res.second.size() > 0); - std::string str(res.second.data(), res.second.size()); - assert(str.find("llama.cpp") != std::string::npos); - } + { + printf("test-arg-parser: test good URL\n\n"); + auto res = common_remote_get_content(GOOD_URL, {}); + assert(res.first == 200); + assert(res.second.size() > 0); + std::string str(res.second.data(), res.second.size()); + assert(str.find("llama.cpp") != std::string::npos); + } - { - printf("test-arg-parser: test bad URL\n\n"); - auto res = common_remote_get_content(BAD_URL, {}); - assert(res.first == 404); - } + { + printf("test-arg-parser: test bad URL\n\n"); + auto res = common_remote_get_content(BAD_URL, {}); + assert(res.first == 404); + } - { - printf("test-arg-parser: test max size error\n"); - common_remote_params params; - params.max_size = 1; - try { - common_remote_get_content(GOOD_URL, params); - assert(false && "it should throw an error"); - } catch (std::exception & e) { - printf(" expected error: %s\n\n", e.what()); - } + { + printf("test-arg-parser: test max size error\n"); + common_remote_params params; + params.max_size = 1; + try { + common_remote_get_content(GOOD_URL, params); + assert(false && "it should throw an error"); + } catch (std::exception & e) { + printf(" expected error: %s\n\n", e.what()); } - - { - printf("test-arg-parser: test timeout error\n"); - common_remote_params params; - params.timeout = 1; - try { - common_remote_get_content(BIG_FILE, params); - assert(false && "it should throw an error"); - } catch (std::exception & e) { - printf(" expected error: %s\n\n", e.what()); - } - } - } else { - printf("test-arg-parser: no curl, skipping curl-related functions\n"); } printf("test-arg-parser: all tests OK\n\n"); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 1b70d94052..62d815cc26 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -2140,6 +2140,27 @@ struct test_set_rows : public test_case { } } } + + double max_nmse_err() override { + if (type == GGML_TYPE_Q4_0 || type == GGML_TYPE_Q4_1 || type == GGML_TYPE_IQ4_NL || + type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1 || type == GGML_TYPE_Q8_0) { + // estimate what the max nmse error would be if one quantized value is + // off by one. The test values are distributed in [-1,1], so it'll be + // roughly (2.0 / 2^bits)^2, divided by the mean square value of the reference, + // which is roughly 0.25 times the number of elements. + double err_estimate = 1.0f/8.0f; + if (type == GGML_TYPE_Q5_0 || type == GGML_TYPE_Q5_1) { + err_estimate /= 2.0f; + } + if (type == GGML_TYPE_Q8_0) { + err_estimate /= 8.0f; + } + err_estimate *= err_estimate; + err_estimate /= 0.25f*float(ne[0] * r * ne[2]*nr23[0] * ne[3]*nr23[1]); + return err_estimate; + } + return 1e-7; + } }; // GGML_OP_ARGMAX @@ -2430,6 +2451,30 @@ struct test_cpy : public test_case { } double max_nmse_err() override { + if (type_src == type_dst) { + return 0.0; + } + if (type_dst == GGML_TYPE_Q4_0 || type_dst == GGML_TYPE_Q4_1 || type_dst == GGML_TYPE_IQ4_NL || + type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1 || type_dst == GGML_TYPE_Q8_0) { + // estimate what the max nmse error would be if one quantized value is + // off by one. The test values are distributed in [-150,150], so it'll be + // roughly (150*2.0 / 2^bits)^2, divided by the mean square value of the reference, + // which is roughly 0.25*150^2 times the number of elements. + double err_estimate = 1.0f/8.0f * 150.0f; + if (type_dst == GGML_TYPE_IQ4_NL) { + // iq4_nl values are a bit more spread out + err_estimate *= 2.0f; + } + if (type_dst == GGML_TYPE_Q5_0 || type_dst == GGML_TYPE_Q5_1) { + err_estimate /= 2.0f; + } + if (type_dst == GGML_TYPE_Q8_0) { + err_estimate /= 8.0f; + } + err_estimate *= err_estimate; + err_estimate /= (150.0f*150.0f*0.25f)*float(ne[0] * ne[1] * ne[2] * ne[3]); + return err_estimate; + } return 1e-6; } @@ -2688,23 +2733,30 @@ struct test_scale : public test_case { const std::array ne; float scale; float bias; + bool inplace; std::string vars() override { - return VARS_TO_STR4(type, ne, scale, bias); + return VARS_TO_STR5(type, ne, scale, bias, inplace); } test_scale(ggml_type type = GGML_TYPE_F32, std::array ne = {10, 10, 10, 10}, float scale = 2.0f, - float bias = 0.0f) - : type(type), ne(ne), scale(scale), bias(bias) {} + float bias = 0.0f, + bool inplace = false) + : type(type), ne(ne), scale(scale), bias(bias), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * out = ggml_scale_bias(ctx, a, scale, bias); + ggml_tensor * out; + if (inplace) { + out = ggml_scale_bias_inplace(ctx, a, scale, bias); + } else { + out = ggml_scale_bias(ctx, a, scale, bias); + } ggml_set_name(out, "out"); return out; @@ -2861,16 +2913,18 @@ struct test_rms_norm : public test_case { const std::array ne; const bool v; // whether a is a non-contiguous view const float eps; + const bool inplace; // whether to do the operation inplace std::string vars() override { - return VARS_TO_STR4(type, ne, v, eps); + return VARS_TO_STR5(type, ne, v, eps, inplace); } test_rms_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 5, 4, 3}, bool v = false, - float eps = 1e-6f) - : type(type), ne(ne), v(v), eps(eps) {} + float eps = 1e-6f, + bool inplace = false) + : type(type), ne(ne), v(v), eps(eps), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); @@ -2882,7 +2936,12 @@ struct test_rms_norm : public test_case { ggml_set_name(a, "view of a"); } - ggml_tensor * out = ggml_rms_norm(ctx, a, eps); + ggml_tensor * out; + if (inplace) { + out = ggml_rms_norm_inplace(ctx, a, eps); + } else { + out = ggml_rms_norm(ctx, a, eps); + } ggml_set_name(out, "out"); return out; @@ -3787,17 +3846,18 @@ struct test_rope : public test_case { bool ff; int v; // view (1 : non-contiguous a) bool forward; + bool inplace; std::string vars() override { // forward can be inferred from the op, does not need to be printed - return VARS_TO_STR10(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v); + return VARS_TO_STR11(type, ne_a, n_dims, mode, n_ctx, fs, ef, af, ff, v, inplace); } test_rope(ggml_type type = GGML_TYPE_F32, std::array ne_a = {10, 5, 3, 1}, - int n_dims = 10, int mode = 0, int n_ctx = 512, float fs = 1.0f, - float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true) - : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward) {} + int n_dims = 10, int mode = GGML_ROPE_TYPE_NORMAL, int n_ctx = 512, float fs = 1.0f, + float ef = 0.0f, float af = 0.0f, bool ff = false, int v = 0, bool forward = true, bool inplace = false) + : type(type), ne_a(ne_a), n_dims(n_dims), mode(mode), n_ctx(n_ctx), fs(fs), ef(ef), af(af), ff(ff), v(v), forward(forward), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a; @@ -3842,7 +3902,11 @@ struct test_rope : public test_case { GGML_ASSERT(n_dims/4 > 0); int rope_sections[4] = {n_dims/4, n_dims/4, 0, 0}; // Vision-RoPE only use first two dimension for image (x, y) coordinate if (forward) { - out = ggml_rope_multi (ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (inplace) { + out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_multi(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } else { out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims/2, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); } @@ -3850,14 +3914,22 @@ struct test_rope : public test_case { GGML_ASSERT(n_dims/3 > 0); int rope_sections[4] = {n_dims/3, n_dims/3, n_dims/3, 0}; if (forward) { - out = ggml_rope_multi (ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (inplace) { + out = ggml_rope_multi_inplace(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_multi(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } else { out = ggml_rope_multi_back(ctx, a, pos, freq, n_dims, rope_sections, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); } } } else { if (forward) { - out = ggml_rope_ext (ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + if (inplace) { + out = ggml_rope_ext_inplace(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } else { + out = ggml_rope_ext(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); + } } else { out = ggml_rope_ext_back(ctx, a, pos, freq, n_dims, mode, 0, 10000.0f, fs, ef, af, 1.0f, 1.0f); } @@ -6138,9 +6210,11 @@ static std::vector> make_test_cases_eval() { //add_test_bin_bcast(type, {3, 3, 2560, 1280}, {2, 1, 1, 1}); } - // single in-place tests, especially important for WebGPU backend since kernels for in-place vs. not are different + // single inplace tests, especially important for WebGPU backend since kernels for inplace vs. not are different test_cases.emplace_back(new test_bin_bcast(ggml_add_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); test_cases.emplace_back(new test_bin_bcast(ggml_mul_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); + test_cases.emplace_back(new test_bin_bcast(ggml_sub_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); + test_cases.emplace_back(new test_bin_bcast(ggml_div_inplace, GGML_TYPE_F32, {16, 5, 4, 3}, {1, 1, 1, 1}, 16)); // fusion test_cases.emplace_back(new test_bin_bcast(ggml_add, GGML_TYPE_F32, {10, 5, 4, 3}, {2, 1, 1, 1}, 2)); @@ -6155,6 +6229,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_add1()); test_cases.emplace_back(new test_scale()); test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f)); + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {10, 10, 10, 10}, 2.0f, 1.0f, true)); // inplace test + test_cases.emplace_back(new test_scale(GGML_TYPE_F32, {100, 10, 10, 10}, 2.0f, 1.0f)); test_cases.emplace_back(new test_softcap(GGML_TYPE_F32, {10, 10, 10, 10}, 50.0f)); test_cases.emplace_back(new test_silu_back()); @@ -6166,6 +6242,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, {64, 5, 4, 3}, eps)); test_cases.emplace_back(new test_l2_norm (GGML_TYPE_F32, {64, 5, 4, 3}, eps)); } + + // in-place tests + test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, false, 1e-6f, true)); + for (float eps : {0.0f, 1e-6f, 1e-4f, 1e-1f, 1.0f}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); @@ -6513,26 +6593,26 @@ static std::vector> make_test_cases_eval() { for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { for (bool ff : {false, true}) { // freq_factors for (float v : { 0, 1 }) { - test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 7B + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 7B if (all) { - test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 13B - test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 30B - test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, 0, 512, fs, ef, af, ff, v, fw)); // llama 65B + test_cases.emplace_back(new test_rope(type, {128, 40, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 13B + test_cases.emplace_back(new test_rope(type, {128, 52, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 30B + test_cases.emplace_back(new test_rope(type, {128, 64, 2, 1}, 128, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); // llama 65B } if (all) { - test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) - test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 1, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 71, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 7B) + test_cases.emplace_back(new test_rope(type, { 64, 8, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 0, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 0, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NORMAL, 512, fs, ef, af, ff, v, fw)); - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, 2, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) - test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) - test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, 2, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 20, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (stablelm) + test_cases.emplace_back(new test_rope(type, { 80, 32, 2, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) + test_cases.emplace_back(new test_rope(type, { 80, 32, 4, 1}, 32, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (phi-2) } if (all) { @@ -6543,7 +6623,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rope(type, { 80, 16, 2, 1}, 80, GGML_ROPE_TYPE_VISION, 512, fs, ef, af, ff, v, fw)); // rope_multi,m-rope (qwen2vl ViT) } - test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, 2, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) + test_cases.emplace_back(new test_rope(type, { 64, 128, 2, 1}, 64, GGML_ROPE_TYPE_NEOX, 512, fs, ef, af, ff, v, fw)); // neox (falcon 40B) } } @@ -6554,6 +6634,15 @@ static std::vector> make_test_cases_eval() { } } + // single inplace test per type/mode/ff + for (ggml_type type : {GGML_TYPE_F32, GGML_TYPE_F16}) { + for (int mode : {GGML_ROPE_TYPE_NORMAL, GGML_ROPE_TYPE_NEOX, GGML_ROPE_TYPE_MROPE, GGML_ROPE_TYPE_VISION}) { + for (bool ff : {false, true}) { + test_cases.emplace_back(new test_rope(type, {128, 32, 2, 1}, 128, mode, 512, 1.4245f, 0.7465f, 1.4245f, ff, 0, true, true)); + } + } + } + for (int v : { 0, 1, 2, 3 }) { for (int dim : { 0, 1, 2, 3, }) { test_cases.emplace_back(new test_concat(GGML_TYPE_F32, {11, 12, 13, 14}, 7, dim, v)); @@ -6566,6 +6655,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16, 10, 10, 10}, order)); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {60, 10, 10, 10}, order)); // qwen test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {1024, 1, 1, 1}, order)); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {16384, 1, 1, 1}, order)); // bailingmoe2 (group selection) } for (ggml_scale_mode mode : {GGML_SCALE_MODE_NEAREST, GGML_SCALE_MODE_BILINEAR}) { diff --git a/tools/main/main.cpp b/tools/main/main.cpp index 083fc0cf26..498e00e3a5 100644 --- a/tools/main/main.cpp +++ b/tools/main/main.cpp @@ -707,6 +707,10 @@ int main(int argc, char ** argv) { embd.push_back(id); + if (params.conversation_mode && !waiting_for_first_input && !llama_vocab_is_eog(vocab, id)) { + assistant_ss << common_token_to_piece(ctx, id, false); + } + // echo this to console input_echo = true; @@ -824,11 +828,7 @@ int main(int argc, char ** argv) { } } - // if current token is not EOG, we add it to current assistant message if (params.conversation_mode && !waiting_for_first_input) { - const auto id = common_sampler_last(smpl); - assistant_ss << common_token_to_piece(ctx, id, false); - if (!prompt.empty()) { prompt.clear(); is_interacting = false; diff --git a/tools/perplexity/perplexity.cpp b/tools/perplexity/perplexity.cpp index c22d187cd4..caf080e8d1 100644 --- a/tools/perplexity/perplexity.cpp +++ b/tools/perplexity/perplexity.cpp @@ -1931,11 +1931,13 @@ static void kl_divergence(llama_context * ctx, const common_params & params) { LOG("Maximum KLD: %10.6f\n", kld_values.back()); LOG("99.9%% KLD: %10.6f\n", percentile(kld_values, 0.999f)); LOG("99.0%% KLD: %10.6f\n", percentile(kld_values, 0.990f)); + LOG("95.0%% KLD: %10.6f\n", percentile(kld_values, 0.950f)); LOG("90.0%% KLD: %10.6f\n", percentile(kld_values, 0.900f)); LOG("Median KLD: %10.6f\n", kld_median); LOG("10.0%% KLD: %10.6f\n", percentile(kld_values, 0.100f)); LOG(" 5.0%% KLD: %10.6f\n", percentile(kld_values, 0.050f)); LOG(" 1.0%% KLD: %10.6f\n", percentile(kld_values, 0.010f)); + LOG(" 0.1%% KLD: %10.6f\n", percentile(kld_values, 0.001f)); LOG("Minimum KLD: %10.6f\n", kld_values.front()); LOG("\n"); diff --git a/tools/run/run.cpp b/tools/run/run.cpp index 772d66c921..b90a7253c4 100644 --- a/tools/run/run.cpp +++ b/tools/run/run.cpp @@ -9,6 +9,7 @@ #include #if defined(_WIN32) +# define WIN32_LEAN_AND_MEAN # ifndef NOMINMAX # define NOMINMAX # endif @@ -22,6 +23,8 @@ #if defined(LLAMA_USE_CURL) # include +#else +# include "http.h" #endif #include @@ -397,7 +400,6 @@ class File { # endif }; -#ifdef LLAMA_USE_CURL class HttpClient { public: int init(const std::string & url, const std::vector & headers, const std::string & output_file, @@ -428,6 +430,8 @@ class HttpClient { return 0; } +#ifdef LLAMA_USE_CURL + ~HttpClient() { if (chunk) { curl_slist_free_all(chunk); @@ -532,6 +536,117 @@ class HttpClient { return curl_easy_perform(curl); } +#else // LLAMA_USE_CURL is not defined + +#define curl_off_t long long // temporary hack + + private: + // this is a direct translation of the cURL download() above + int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, + const bool progress, std::string * response_str = nullptr) { + try { + auto [cli, url_parts] = common_http_client(url); + + httplib::Headers headers; + for (const auto & h : headers_vec) { + size_t pos = h.find(':'); + if (pos != std::string::npos) { + headers.emplace(h.substr(0, pos), h.substr(pos + 2)); + } + } + + File out; + if (!output_file.empty()) { + if (!out.open(output_file, "ab")) { + printe("Failed to open file for writing\n"); + return 1; + } + if (out.lock()) { + printe("Failed to exclusively lock file\n"); + return 1; + } + } + + size_t resume_offset = 0; + if (!output_file.empty() && std::filesystem::exists(output_file)) { + resume_offset = std::filesystem::file_size(output_file); + if (resume_offset > 0) { + headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); + } + } + + progress_data data; + data.file_size = resume_offset; + + long long total_size = 0; + long long received_this_session = 0; + + auto response_handler = + [&](const httplib::Response & response) { + if (resume_offset > 0 && response.status != 206) { + printe("\nServer does not support resuming. Restarting download.\n"); + out.file = freopen(output_file.c_str(), "wb", out.file); + if (!out.file) { + return false; + } + data.file_size = 0; + } + if (progress) { + if (response.has_header("Content-Length")) { + total_size = std::stoll(response.get_header_value("Content-Length")); + } else if (response.has_header("Content-Range")) { + auto range = response.get_header_value("Content-Range"); + auto slash = range.find('/'); + if (slash != std::string::npos) { + total_size = std::stoll(range.substr(slash + 1)); + } + } + } + return true; + }; + + auto content_receiver = + [&](const char * chunk, size_t length) { + if (out.file && fwrite(chunk, 1, length, out.file) != length) { + return false; + } + if (response_str) { + response_str->append(chunk, length); + } + received_this_session += length; + + if (progress && total_size > 0) { + update_progress(&data, total_size, received_this_session, 0, 0); + } + return true; + }; + + auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); + + if (data.printed) { + printe("\n"); + } + + if (!res) { + auto err = res.error(); + printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); + return 1; + } + + if (res->status >= 400) { + printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); + return 1; + } + + } catch (const std::exception & e) { + printe("HTTP request failed: %s\n", e.what()); + return 1; + } + return 0; + } + +#endif // LLAMA_USE_CURL + static std::string human_readable_time(double seconds) { int hrs = static_cast(seconds) / 3600; int mins = (static_cast(seconds) % 3600) / 60; @@ -644,8 +759,8 @@ class HttpClient { str->append(static_cast(ptr), size * nmemb); return size * nmemb; } + }; -#endif class LlamaData { public: @@ -673,7 +788,6 @@ class LlamaData { } private: -#ifdef LLAMA_USE_CURL int download(const std::string & url, const std::string & output_file, const bool progress, const std::vector & headers = {}, std::string * response_str = nullptr) { HttpClient http; @@ -683,14 +797,6 @@ class LlamaData { return 0; } -#else - int download(const std::string &, const std::string &, const bool, const std::vector & = {}, - std::string * = nullptr) { - printe("%s: llama.cpp built without libcurl, downloading from an url not supported.\n", __func__); - - return 1; - } -#endif // Helper function to handle model tag extraction and URL construction std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index 15c18e2d49..4f18a634ce 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/scripts/dev.sh b/tools/server/webui/scripts/dev.sh index e0e8b26e9a..2bda8f22c8 100644 --- a/tools/server/webui/scripts/dev.sh +++ b/tools/server/webui/scripts/dev.sh @@ -1,5 +1,14 @@ #!/bin/bash +# Development script for llama.cpp webui +# +# This script starts the webui development servers (Storybook and Vite). +# Note: You need to start llama-server separately. +# +# Usage: +# bash scripts/dev.sh +# npm run dev + cd ../../../ # Check and install git hooks if missing @@ -28,76 +37,19 @@ check_and_install_hooks() { # Install git hooks if needed check_and_install_hooks -# Check if llama-server binary already exists -if [ ! -f "build/bin/llama-server" ]; then - echo "Building llama-server..." - cmake -B build && cmake --build build --config Release -t llama-server -else - echo "llama-server binary already exists, skipping build." -fi - -# Start llama-server and capture output -echo "Starting llama-server..." -mkfifo server_output.pipe -build/bin/llama-server -hf ggml-org/gpt-oss-20b-GGUF --jinja -c 0 --no-webui > server_output.pipe 2>&1 & -SERVER_PID=$! - -# Function to wait for server to be ready -wait_for_server() { - echo "Waiting for llama-server to be ready..." - local max_wait=60 - local start_time=$(date +%s) - - # Read server output in background and look for the ready message - ( - while IFS= read -r line; do - echo "🔍 Server: $line" - if [[ "$line" == *"server is listening on http://127.0.0.1:8080 - starting the main loop"* ]]; then - echo "✅ llama-server is ready!" - echo "READY" > server_ready.flag - break - fi - done < server_output.pipe - ) & - - # Wait for ready flag or timeout - while [ ! -f server_ready.flag ]; do - local current_time=$(date +%s) - local elapsed=$((current_time - start_time)) - - if [ $elapsed -ge $max_wait ]; then - echo "❌ Server failed to start within $max_wait seconds" - rm -f server_ready.flag - return 1 - fi - - sleep 1 - done - - rm -f server_ready.flag - return 0 -} - # Cleanup function cleanup() { echo "🧹 Cleaning up..." - kill $SERVER_PID 2>/dev/null - rm -f server_output.pipe server_ready.flag exit } # Set up signal handlers trap cleanup SIGINT SIGTERM -# Wait for server to be ready -if wait_for_server; then - echo "🚀 Starting development servers..." - cd tools/server/webui - storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & - - # Wait for all background processes - wait -else - echo "❌ Failed to start development environment" - cleanup -fi +echo "🚀 Starting development servers..." +echo "📝 Note: Make sure to start llama-server separately if needed" +cd tools/server/webui +storybook dev -p 6006 --ci & vite dev --host 0.0.0.0 & + +# Wait for all background processes +wait diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index a9ac80ab9c..2ca1536409 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -37,8 +37,9 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.225 0 0); - --code-foreground: oklch(0.875 0 0); + --code-background: oklch(0.975 0 0); + --code-foreground: oklch(0.145 0 0); + --layer-popover: 1000000; } .dark { @@ -73,6 +74,8 @@ --sidebar-accent-foreground: oklch(0.985 0 0); --sidebar-border: oklch(1 0 0 / 10%); --sidebar-ring: oklch(0.556 0 0); + --code-background: oklch(0.225 0 0); + --code-foreground: oklch(0.875 0 0); } @theme inline { diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte index 013b77cbbe..ad3ffa3792 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageAssistant.svelte @@ -3,12 +3,14 @@ import { useProcessingState } from '$lib/hooks/use-processing-state.svelte'; import { isLoading } from '$lib/stores/chat.svelte'; import { fade } from 'svelte/transition'; - import { Check, X } from '@lucide/svelte'; + import { Check, Copy, Package, X } from '@lucide/svelte'; import { Button } from '$lib/components/ui/button'; import { Checkbox } from '$lib/components/ui/checkbox'; import { INPUT_CLASSES } from '$lib/constants/input-classes'; import ChatMessageActions from './ChatMessageActions.svelte'; import Label from '$lib/components/ui/label/label.svelte'; + import { config } from '$lib/stores/settings.svelte'; + import { copyToClipboard } from '$lib/utils/copy'; interface Props { class?: string; @@ -136,6 +138,23 @@ {/if} + {#if config().showModelInfo && message.model} + + + + Model used: + + + + {/if} + {#if message.timestamp && !isEditing}
@@ -441,7 +447,7 @@
- +