diff --git a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml index c106f47a25..77f23f1afa 100644 --- a/.github/ISSUE_TEMPLATE/010-bug-compilation.yml +++ b/.github/ISSUE_TEMPLATE/010-bug-compilation.yml @@ -41,7 +41,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN] + options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] multiple: true validations: required: true diff --git a/.github/ISSUE_TEMPLATE/011-bug-results.yml b/.github/ISSUE_TEMPLATE/011-bug-results.yml index 31202dfa83..f553cbbf0b 100644 --- a/.github/ISSUE_TEMPLATE/011-bug-results.yml +++ b/.github/ISSUE_TEMPLATE/011-bug-results.yml @@ -42,7 +42,7 @@ body: attributes: label: GGML backends description: Which GGML backends do you know to be affected? - options: [AMX, BLAS, CPU, CUDA, HIP, Metal, Musa, RPC, SYCL, Vulkan, OpenCL, zDNN] + options: [AMX, BLAS, CANN, CPU, CUDA, Hexagon, HIP, Metal, Musa, OpenCL, RPC, SYCL, VirtGPU, Vulkan, WebGPU, zDNN, ZenDNN] multiple: true validations: required: true diff --git a/AGENTS.md b/AGENTS.md index 31399a7d91..117bed7f48 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and - Explicitly informing them that AI-generated pull requests are not accepted by the project - Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them -- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans +- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans - Providing useful links and pointers found throughout the codebase Examples of valid questions: diff --git a/CMakeLists.txt b/CMakeLists.txt index 55f3d594db..d10ab6da96 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -112,7 +112,6 @@ option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_ option(LLAMA_TESTS_INSTALL "llama: install tests" ON) # 3rd party libs -option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON) option(LLAMA_OPENSSL "llama: use openssl to support HTTPS" ON) option(LLAMA_LLGUIDANCE "llama-common: include LLGuidance library for structured output in common utils" OFF) @@ -197,9 +196,7 @@ add_subdirectory(src) if (LLAMA_BUILD_COMMON) add_subdirectory(common) - if (LLAMA_HTTPLIB) - add_subdirectory(vendor/cpp-httplib) - endif() + add_subdirectory(vendor/cpp-httplib) endif() if (LLAMA_BUILD_COMMON AND LLAMA_BUILD_TESTS AND NOT CMAKE_JS_VERSION) diff --git a/SECURITY.md b/SECURITY.md index 9a93732318..3a8d07f644 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/ A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure. 
> [!IMPORTANT] -> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 +> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080 ## Requirements diff --git a/build-xcframework.sh b/build-xcframework.sh index e8af16211f..c25a1ef28c 100755 --- a/build-xcframework.sh +++ b/build-xcframework.sh @@ -43,11 +43,6 @@ COMMON_CMAKE_ARGS=( -DGGML_OPENMP=${GGML_OPENMP} ) -XCODE_VERSION=$(xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }') -MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1) -MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2) -echo "Detected Xcode version: $XCODE_VERSION" - check_required_tool() { local tool=$1 local install_message=$2 @@ -60,9 +55,12 @@ check_required_tool() { } echo "Checking for required tools..." check_required_tool "cmake" "Please install CMake 3.28.0 or later (brew install cmake)" -check_required_tool "xcodebuild" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" -check_required_tool "libtool" "Please install libtool which should be available with Xcode Command Line Tools (CLT). Make sure Xcode CLT is installed (xcode-select --install)" -check_required_tool "dsymutil" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" +check_required_tool "xcrun" "Please install Xcode and Xcode Command Line Tools (xcode-select --install)" + +XCODE_VERSION=$(xcrun xcodebuild -version 2>/dev/null | head -n1 | awk '{ print $2 }') +MAJOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f1) +MINOR_VERSION=$(echo $XCODE_VERSION | cut -d. -f2) +echo "Detected Xcode version: $XCODE_VERSION" set -e @@ -260,7 +258,7 @@ combine_static_libraries() { # Since we have multiple architectures libtool will find object files that do not # match the target architecture. We suppress these warnings. - libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null + xcrun libtool -static -o "${temp_dir}/combined.a" "${libs[@]}" 2> /dev/null # Determine SDK, architectures, and install_name based on platform and simulator flag. local sdk="" @@ -333,7 +331,7 @@ combine_static_libraries() { # Platform-specific post-processing for device builds if [[ "$is_simulator" == "false" ]]; then - if command -v xcrun vtool &>/dev/null; then + if xcrun -f vtool &>/dev/null; then case "$platform" in "ios") echo "Marking binary as a framework binary for iOS..." @@ -451,10 +449,9 @@ cmake -B build-visionos -G Xcode \ -DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_OSX_SYSROOT=xros \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xros \ - -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ - -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ + -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DLLAMA_OPENSSL=OFF \ - -DLLAMA_HTTPLIB=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -S . cmake --build build-visionos --config Release -- -quiet @@ -467,10 +464,9 @@ cmake -B build-visionos-sim -G Xcode \ -DCMAKE_SYSTEM_NAME=visionOS \ -DCMAKE_OSX_SYSROOT=xrsimulator \ -DCMAKE_XCODE_ATTRIBUTE_SUPPORTED_PLATFORMS=xrsimulator \ - -DCMAKE_C_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_C_FLAGS}" \ - -DCMAKE_CXX_FLAGS="-D_XOPEN_SOURCE=700 ${COMMON_CXX_FLAGS}" \ + -DCMAKE_C_FLAGS="${COMMON_C_FLAGS}" \ + -DCMAKE_CXX_FLAGS="${COMMON_CXX_FLAGS}" \ -DLLAMA_OPENSSL=OFF \ - -DLLAMA_HTTPLIB=OFF \ -DLLAMA_BUILD_SERVER=OFF \ -S . 
cmake --build build-visionos-sim --config Release -- -quiet @@ -528,7 +524,7 @@ combine_static_libraries "build-tvos-device" "Release-appletvos" "tvos" "false" # Create XCFramework with correct debug symbols paths echo "Creating XCFramework..." -xcodebuild -create-xcframework \ +xcrun xcodebuild -create-xcframework \ -framework $(pwd)/build-ios-sim/framework/llama.framework \ -debug-symbols $(pwd)/build-ios-sim/dSYMs/llama.dSYM \ -framework $(pwd)/build-ios-device/framework/llama.framework \ diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index 295ae9ea25..b6b984d502 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -112,11 +112,7 @@ endif() # TODO: use list(APPEND LLAMA_COMMON_EXTRA_LIBS ...) set(LLAMA_COMMON_EXTRA_LIBS build_info) - -if (LLAMA_HTTPLIB) - target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_HTTPLIB) - set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib) -endif() +set(LLAMA_COMMON_EXTRA_LIBS ${LLAMA_COMMON_EXTRA_LIBS} cpp-httplib) if (LLAMA_LLGUIDANCE) include(ExternalProject) diff --git a/common/arg.cpp b/common/arg.cpp index 9c85696ebd..18f953a38e 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.kv_unified = value; } - ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH})); + ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL})); add_opt(common_arg( {"--context-shift"}, {"--no-context-shift"}, diff --git a/common/common.cpp b/common/common.cpp index 3aa396127c..32487ddc61 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1,7 +1,3 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "ggml.h" #include "gguf.h" @@ -9,12 +5,12 @@ #include "log.h" #include "llama.h" #include "sampling.h" +#include "unicode.h" #include #include #include #include -#include #include #include #include @@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { return false; } - std::u32string filename_utf32; - try { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif + size_t offset = 0; + while (offset < filename.size()) { + utf8_parse_result result = parse_utf8_codepoint(filename, offset); - std::wstring_convert, char32_t> converter; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - filename_utf32 = converter.from_bytes(filename); - - // If the reverse conversion mismatches, it means overlong UTF-8 sequences were used, - // or invalid encodings were encountered. 
Reject such attempts - std::string filename_reencoded = converter.to_bytes(filename_utf32); - if (filename_reencoded != filename) { + if (result.status != utf8_parse_result::SUCCESS) { return false; } - } catch (const std::exception &) { - return false; - } + uint32_t c = result.codepoint; - // Check for forbidden codepoints: - // - Control characters - // - Unicode equivalents of illegal characters - // - UTF-16 surrogate pairs - // - UTF-8 replacement character - // - Byte order mark (BOM) - // - Illegal characters: / \ : * ? " < > | - for (char32_t c : filename_utf32) { + if ((result.bytes_consumed == 2 && c < 0x80) || + (result.bytes_consumed == 3 && c < 0x800) || + (result.bytes_consumed == 4 && c < 0x10000)) { + return false; + } + + // Check for forbidden codepoints: + // - Control characters + // - Unicode equivalents of illegal characters + // - UTF-16 surrogate pairs + // - UTF-8 replacement character + // - Byte order mark (BOM) + // - Illegal characters: / \ : * ? " < > | if (c <= 0x1F // Control characters (C0) || c == 0x7F // Control characters (DEL) || (c >= 0x80 && c <= 0x9F) // Control characters (C1) @@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { || c == 0x2215 // Division Slash (forward slash equivalent) || c == 0x2216 // Set Minus (backslash equivalent) || (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs + || c > 0x10FFFF // Max Unicode limit || c == 0xFFFD // Replacement Character (UTF-8) || c == 0xFEFF // Byte Order Mark (BOM) || c == ':' || c == '*' // Illegal characters @@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) { // Subdirectories not allowed, reject path separators return false; } + offset += result.bytes_consumed; } // Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename @@ -898,7 +879,8 @@ std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || defined(__NetBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else if (std::getenv("HOME")) { @@ -1242,7 +1224,7 @@ common_init_result_ptr common_init_from_params(common_params & params) { return res; } - int err = llama_apply_adapter_cvec( + int err = llama_set_adapter_cvec( lctx, cvec.data.data(), cvec.data.size(), @@ -1344,12 +1326,15 @@ std::string get_model_endpoint() { } void common_set_adapter_lora(struct llama_context * ctx, std::vector & lora) { - llama_clear_adapter_lora(ctx); - for (auto & la : lora) { - if (la.scale != 0.0f) { - llama_set_adapter_lora(ctx, la.ptr, la.scale); - } + std::vector loras; + std::vector scales; + + for (auto & la: lora) { + loras.push_back(la.ptr); + scales.push_back(la.scale); } + + llama_set_adapters_lora(ctx, loras.data(), loras.size(), scales.data()); } struct llama_model_params common_model_params_to_llama(common_params & params) { @@ -1469,66 +1454,6 @@ void common_batch_add( batch.n_tokens++; } -// -// Token utils -// - -size_t common_lcp(const llama_tokens & a, const llama_tokens & b) { - size_t i; - for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {} - - return i; -} - -size_t common_lcs(const llama_tokens & a, const llama_tokens & b) { - // check for empty sequences - if (a.empty() || 
b.empty()) { - return 0; - } - - // get the lengths of the input sequences - size_t a_len = a.size(); - size_t b_len = b.size(); - - // initialize the maximum length of the longest common subsequence (LCS) - size_t max_length = 0; - - // use two rows instead of a 2D matrix to optimize space - std::vector prev_row(b_len + 1, 0); - std::vector curr_row(b_len + 1, 0); - - // iterate through the elements of a - for (size_t i = 1; i <= a_len; i++) { - // iterate through the elements of b - for (size_t j = 1; j <= b_len; j++) { - // if elements at the current positions match - if (a[i - 1] == b[j - 1]) { - // if it's the first element of either sequences, set LCS length to 1 - if (i == 1 || j == 1) { - curr_row[j] = 1; - } else { - // increment LCS length by 1 compared to the previous element - curr_row[j] = prev_row[j - 1] + 1; - } - - // update max_length if necessary - if (curr_row[j] > max_length) { - max_length = curr_row[j]; - } - } else { - // reset LCS length if elements don't match - curr_row[j] = 0; - } - } - - // update the previous row for the next iteration - prev_row = curr_row; - } - - // return the maximum length of the LCS - return max_length; -} - // // Vocab utils // diff --git a/common/common.h b/common/common.h index b284244530..804485fb19 100644 --- a/common/common.h +++ b/common/common.h @@ -779,16 +779,6 @@ void common_batch_add( const std::vector & seq_ids, bool logits); -// -// Token utils -// - -// longest common prefix -size_t common_lcp(const llama_tokens & a, const llama_tokens & b); - -// longet common subsequence -size_t common_lcs(const llama_tokens & a, const llama_tokens & b); - // // Vocab utils // diff --git a/common/download.cpp b/common/download.cpp index 8710438aa4..5ef60a4208 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -19,9 +19,7 @@ #include #include -#if defined(LLAMA_USE_HTTPLIB) #include "http.h" -#endif #ifndef __EMSCRIPTEN__ #ifdef __linux__ @@ -114,44 +112,18 @@ static void write_etag(const std::string & path, const std::string & etag) { } static std::string read_etag(const std::string & path) { - std::string none; const std::string etag_path = path + ".etag"; - - if (std::filesystem::exists(etag_path)) { - std::ifstream etag_in(etag_path); - if (!etag_in) { - LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); - return none; - } - std::string etag; - std::getline(etag_in, etag); - return etag; + if (!std::filesystem::exists(etag_path)) { + return {}; } - - // no etag file, but maybe there is an old .json - // remove this code later - const std::string metadata_path = path + ".json"; - - if (std::filesystem::exists(metadata_path)) { - std::ifstream metadata_in(metadata_path); - try { - nlohmann::json metadata_json; - metadata_in >> metadata_json; - LOG_DBG("%s: previous metadata file found %s: %s\n", __func__, metadata_path.c_str(), - metadata_json.dump().c_str()); - if (metadata_json.contains("etag") && metadata_json.at("etag").is_string()) { - std::string etag = metadata_json.at("etag"); - write_etag(path, etag); - if (!std::filesystem::remove(metadata_path)) { - LOG_WRN("%s: failed to delete old .json metadata file: %s\n", __func__, metadata_path.c_str()); - } - return etag; - } - } catch (const nlohmann::json::exception & e) { - LOG_ERR("%s: error reading metadata file %s: %s\n", __func__, metadata_path.c_str(), e.what()); - } + std::ifstream etag_in(etag_path); + if (!etag_in) { + LOG_ERR("%s: could not open .etag file for reading: %s\n", __func__, etag_path.c_str()); + return {}; } - 
return none; + std::string etag; + std::getline(etag_in, etag); + return etag; } static bool is_http_status_ok(int status) { @@ -168,8 +140,6 @@ std::pair common_download_split_repo_tag(const std::st return {hf_repo, tag}; } -#if defined(LLAMA_USE_HTTPLIB) - class ProgressBar { static inline std::mutex mutex; static inline std::map lines; @@ -347,62 +317,64 @@ static int common_download_file_single_online(const std::string & url, LOG_INF("%s: no previous model file found %s\n", __func__, path.c_str()); } - for (int i = 0; i < max_attempts; ++i) { - auto head = cli.Head(parts.path); - bool head_ok = head && head->status >= 200 && head->status < 300; - if (!head_ok) { - LOG_WRN("%s: HEAD invalid http status code received: %d\n", __func__, head ? head->status : -1); - if (file_exists) { - LOG_INF("%s: Using cached file (HEAD failed): %s\n", __func__, path.c_str()); - return 304; // 304 Not Modified - fake cached response - } - return head->status; // cannot use cached file, return raw status code - // TODO: maybe retry only on certain codes - } - - std::string etag; - if (head_ok && head->has_header("ETag")) { - etag = head->get_header_value("ETag"); - } - - size_t total_size = 0; - if (head_ok && head->has_header("Content-Length")) { - try { - total_size = std::stoull(head->get_header_value("Content-Length")); - } catch (const std::exception& e) { - LOG_WRN("%s: Invalid Content-Length in HEAD response: %s\n", __func__, e.what()); - } - } - - bool supports_ranges = false; - if (head_ok && head->has_header("Accept-Ranges")) { - supports_ranges = head->get_header_value("Accept-Ranges") != "none"; - } - - bool should_download_from_scratch = false; - if (!last_etag.empty() && !etag.empty() && last_etag != etag) { - LOG_WRN("%s: ETag header is different (%s != %s): triggering a new download\n", __func__, - last_etag.c_str(), etag.c_str()); - should_download_from_scratch = true; - } - + auto head = cli.Head(parts.path); + if (!head || head->status < 200 || head->status >= 300) { + LOG_WRN("%s: HEAD failed, status: %d\n", __func__, head ? head->status : -1); if (file_exists) { - if (!should_download_from_scratch) { - LOG_INF("%s: using cached file: %s\n", __func__, path.c_str()); - return 304; // 304 Not Modified - fake cached response - } - LOG_WRN("%s: deleting previous downloaded file: %s\n", __func__, path.c_str()); - if (remove(path.c_str()) != 0) { - LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); - return -1; - } + LOG_INF("%s: using cached file (HEAD failed): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + return head ? 
head->status : -1; + } + + std::string etag; + if (head->has_header("ETag")) { + etag = head->get_header_value("ETag"); + } + + size_t total_size = 0; + if (head->has_header("Content-Length")) { + try { + total_size = std::stoull(head->get_header_value("Content-Length")); + } catch (const std::exception& e) { + LOG_WRN("%s: invalid Content-Length in HEAD response: %s\n", __func__, e.what()); + } + } + + bool supports_ranges = false; + if (head->has_header("Accept-Ranges")) { + supports_ranges = head->get_header_value("Accept-Ranges") != "none"; + } + + if (file_exists) { + if (etag.empty()) { + LOG_INF("%s: using cached file (no server etag): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + if (!last_etag.empty() && last_etag == etag) { + LOG_INF("%s: using cached file (same etag): %s\n", __func__, path.c_str()); + return 304; // 304 Not Modified - fake cached response + } + if (remove(path.c_str()) != 0) { + LOG_ERR("%s: unable to delete file: %s\n", __func__, path.c_str()); + return -1; + } + } + + const std::string path_temporary = path + ".downloadInProgress"; + int delay = retry_delay_seconds; + + for (int i = 0; i < max_attempts; ++i) { + if (i) { + LOG_WRN("%s: retrying after %d seconds...\n", __func__, delay); + std::this_thread::sleep_for(std::chrono::seconds(delay)); + delay *= retry_delay_seconds; } - const std::string path_temporary = path + ".downloadInProgress"; size_t existing_size = 0; if (std::filesystem::exists(path_temporary)) { - if (supports_ranges && !should_download_from_scratch) { + if (supports_ranges) { existing_size = std::filesystem::file_size(path_temporary); } else if (remove(path_temporary.c_str()) != 0) { LOG_ERR("%s: unable to delete file: %s\n", __func__, path_temporary.c_str()); @@ -410,32 +382,23 @@ static int common_download_file_single_online(const std::string & url, } } - // start the download - LOG_INF("%s: trying to download model from %s to %s (etag:%s)...\n", - __func__, common_http_show_masked_url(parts).c_str(), path_temporary.c_str(), etag.c_str()); - const bool was_pull_successful = common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size); - if (!was_pull_successful) { - if (i + 1 < max_attempts) { - const int exponential_backoff_delay = std::pow(retry_delay_seconds, i) * 1000; - LOG_WRN("%s: retrying after %d milliseconds...\n", __func__, exponential_backoff_delay); - std::this_thread::sleep_for(std::chrono::milliseconds(exponential_backoff_delay)); - } else { - LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); + LOG_INF("%s: downloading from %s to %s (etag:%s)...\n", + __func__, common_http_show_masked_url(parts).c_str(), + path_temporary.c_str(), etag.c_str()); + + if (common_pull_file(cli, parts.path, path_temporary, supports_ranges, existing_size, total_size)) { + if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { + LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); + return -1; } - continue; + if (!etag.empty()) { + write_etag(path, etag); + } + return head->status; } - - if (std::rename(path_temporary.c_str(), path.c_str()) != 0) { - LOG_ERR("%s: unable to rename file: %s to %s\n", __func__, path_temporary.c_str(), path.c_str()); - return -1; - } - if (!etag.empty()) { - write_etag(path, etag); - } - - return head->status; // TODO: use actual GET status? 
} + LOG_ERR("%s: download failed after %d attempts\n", __func__, max_attempts); return -1; // max attempts reached } @@ -801,30 +764,6 @@ std::string common_docker_resolve_model(const std::string & docker) { } } -#else - -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -std::string common_docker_resolve_model(const std::string &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -int common_download_file_single(const std::string &, - const std::string &, - const std::string &, - bool, - const common_header_list &) { - throw std::runtime_error("download functionality is not enabled in this build"); -} - -#endif // defined(LLAMA_USE_HTTPLIB) - std::vector common_list_cached_models() { std::vector models; const std::string cache_dir = fs_get_cache_directory(); diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index 724cb1cc30..ddf70e23b2 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -570,6 +570,7 @@ class ModelBase: self.match_model_tensor_name(new_name, key, bid) for key in ( gguf.MODEL_TENSOR.FFN_GATE_INP, + gguf.MODEL_TENSOR.FFN_GATE_INP_SHEXP, gguf.MODEL_TENSOR.POS_EMBD, gguf.MODEL_TENSOR.TOKEN_TYPES, gguf.MODEL_TENSOR.SSM_CONV1D, @@ -1611,6 +1612,23 @@ class TextModel(ModelBase): special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|endoftext|>"]) special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_glm(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + # Special tokens + # Note: Using <|endoftext|> (151329) for eot causes endless generation + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 + special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 + special_vocab.add_to_gguf(self.gguf_writer) + def _set_vocab_interns1(self): tokens: list[str] = [] toktypes: list[int] = [] @@ -2711,8 +2729,6 @@ class AfmoeModel(LlamaModel): super().set_gguf_parameters() # MoE parameters - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (n_shared_experts := self.hparams.get("num_shared_experts")) is not None: self.gguf_writer.add_expert_shared_count(n_shared_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: @@ -2734,7 +2750,7 @@ class AfmoeModel(LlamaModel): # Handle expert weights - they're already merged in the HF format # process the experts separately if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts 
is None: @@ -4059,6 +4075,87 @@ class InternVisionModel(MmprojModel): yield from super().modify_tensors(data_torch, name, bid) +@ModelBase.register( + "NemotronH_Nano_VL_V2", + "RADIOModel", +) +class NemotronNanoV2VLModel(MmprojModel): + # ViT-Huge architecture parameters for RADIO v2.5-h + _vit_hidden_size = 1280 + _vit_intermediate_size = 5120 + _vit_num_layers = 32 + _vit_num_heads = 16 + + def get_vision_config(self) -> dict[str, Any] | None: + # RADIO config doesn't have standard ViT parameters, so they need to be constructed manually + vision_config = self.global_config.get("vision_config") + if vision_config is None: + return None + # Add ViT-H parameters + vision_config = { + **vision_config, + "hidden_size": self._vit_hidden_size, + "intermediate_size": self._vit_intermediate_size, + "num_hidden_layers": self._vit_num_layers, + "num_attention_heads": self._vit_num_heads, + "image_size": self.global_config.get("force_image_size", 512), + } + return vision_config + + def set_gguf_parameters(self): + if "image_mean" not in self.preprocessor_config: + self.preprocessor_config["image_mean"] = [0.485, 0.456, 0.406] + if "image_std" not in self.preprocessor_config: + self.preprocessor_config["image_std"] = [0.229, 0.224, 0.225] + + super().set_gguf_parameters() + hparams = self.global_config + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.NEMOTRON_V2_VL) + self.gguf_writer.add_vision_attention_layernorm_eps(1e-6) + self.gguf_writer.add_vision_use_gelu(True) + downsample_ratio = hparams.get("downsample_ratio", 0.5) + self.gguf_writer.add_vision_projector_scale_factor(int(1.0 / downsample_ratio)) + + def tensor_force_quant(self, name, new_name, bid, n_dims): + if ".position_embd." in new_name or "pos_embed" in new_name: + return gguf.GGMLQuantizationType.F32 + return super().tensor_force_quant(name, new_name, bid, n_dims) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if "input_conditioner" in name: + return + + # RADIO's pos_embed doesn't have .weight suffix, but clip.cpp expects it + if "patch_generator.pos_embed" in name: + if not name.endswith(".weight"): + name += ".weight" + # Downsample position embeddings for fixed 512x512 image size + import torch.nn.functional as F + n_embd = self.hparams["hidden_size"] + image_size = self.global_config.get("force_image_size", 512) + patch_size = self.hparams["patch_size"] + target_patches_per_side = image_size // patch_size # 32 + max_patches_per_side = int((data_torch.shape[1]) ** 0.5) # 128 + if target_patches_per_side != max_patches_per_side: + # Reshape to grid, interpolate, flatten back + data_torch = data_torch.reshape(1, max_patches_per_side, max_patches_per_side, n_embd) + data_torch = data_torch.permute(0, 3, 1, 2).float() # [1, n_embd, 128, 128] + data_torch = F.interpolate(data_torch, size=(target_patches_per_side, target_patches_per_side), + mode='bilinear', align_corners=True) + data_torch = data_torch.permute(0, 2, 3, 1) # [1, 32, 32, n_embd] + data_torch = data_torch.reshape(1, target_patches_per_side * target_patches_per_side, n_embd) + + # Reshape linear patch embedding to conv2d format for ggml_conv_2d + # From [n_embd, patch_size*patch_size*3] to [n_embd, 3, patch_size, patch_size] + if "patch_generator.embedder" in name: + patch_size = self.hparams["patch_size"] + n_embd = self.hparams["hidden_size"] + data_torch = data_torch.reshape(n_embd, 3, patch_size, patch_size) + + if name.startswith("vision_model.radio_model.model.") or 
name.startswith("mlp1."): + yield from super().modify_tensors(data_torch, name, bid) + + @ModelBase.register("WavTokenizerDec") class WavTokenizerDecModel(TextModel): model_arch = gguf.MODEL_ARCH.WAVTOKENIZER_DEC @@ -4101,8 +4198,6 @@ class Qwen2MoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") @@ -4147,7 +4242,7 @@ class Qwen2MoeModel(TextModel): return if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -4898,13 +4993,13 @@ class PhiMoeModel(Phi3MiniModel): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) - self.gguf_writer.add_expert_count(self.hparams["num_local_experts"]) + self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"])) + self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"])) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("block_sparse_moe.experts") != -1: - n_experts = self.hparams["num_local_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -5316,7 +5411,7 @@ class KimiLinearModel(TextModel): # process the experts separately if name.find("block_sparse_moe.experts") != -1: - n_experts = self.find_hparam(["num_local_experts", "num_experts"], optional=False) + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -5911,12 +6006,13 @@ class NomicBertModel(BertModel): if "mlp.experts.bias" in name: return # Explicitly return. 
+ n_experts = self.find_hparam(["num_local_experts", "num_experts"]) if "mlp.experts.mlp.w1" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"]) name += ".weight" if "mlp.experts.mlp.w2" in name: - data_torch = data_torch.view(self.hparams["num_experts"], self.hparams["n_inner"], self.hparams["n_embd"]) + data_torch = data_torch.view(n_experts, self.hparams["n_inner"], self.hparams["n_embd"]) data_torch = data_torch.transpose(1, 2) name += ".weight" @@ -5926,7 +6022,6 @@ class NomicBertModel(BertModel): super().set_gguf_parameters() if self.is_moe: self.gguf_writer.add_moe_every_n_layers(self.hparams["moe_every_n_layers"]) - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"]) def _is_tokenizer_xlmroberta(self) -> bool: @@ -7102,6 +7197,8 @@ class Mamba2Model(TextModel): if hparams is None: with open(dir_model / "config.json", "r", encoding="utf-8") as f: hparams = json.load(f) + if "llm_config" in hparams: + hparams["text_config"] = hparams["llm_config"] super().__init__(dir_model, *args, hparams=hparams, **kwargs) self.d_model = self.find_hparam(["hidden_size", "d_model", "dim"]) self.d_inner = self.find_hparam(["mamba_d_ssm", "intermediate_size", "d_inner"], optional=True) or 2 * self.d_model @@ -7223,8 +7320,8 @@ class JambaModel(TextModel): self.gguf_writer.add_ssm_state_size(d_state) self.gguf_writer.add_ssm_time_step_rank(dt_rank) self.gguf_writer.add_layer_norm_rms_eps(rms_norm_eps) - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) - self.gguf_writer.add_expert_used_count(self.hparams["num_experts_per_tok"]) + self.gguf_writer.add_expert_count(self.find_hparam(["num_local_experts", "num_experts"])) + self.gguf_writer.add_expert_used_count(self.find_hparam(["num_experts_per_tok", "num_experts_per_token"])) self.gguf_writer.add_file_type(self.ftype) _experts: list[dict[str, Tensor]] | None = None @@ -7242,7 +7339,7 @@ class JambaModel(TextModel): # process the experts separately if ".feed_forward.experts." 
in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None @@ -7390,8 +7487,6 @@ class OlmoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() self.gguf_writer.add_layer_norm_rms_eps(1e-5) - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) _experts: list[dict[str, Tensor]] | None = None @@ -7399,7 +7494,7 @@ class OlmoeModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -7775,6 +7870,9 @@ class DeepseekModel(TextModel): class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 + # TODO @ngxson : remove this when we support MTP for deepseek models + skip_mtp = True + def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) hparams: dict = ModelBase.load_hparams(self.dir_model, is_mistral_format=False) @@ -7931,10 +8029,11 @@ class DeepseekV2Model(TextModel): name = name.replace("e_score_correction_bias", "e_score_correction.bias") # skip Multi-Token Prediction (MTP) layers - block_count = self.hparams["num_hidden_layers"] - match = re.match(r"model.layers.(\d+)", name) - if match and int(match.group(1)) >= block_count: - return + if self.skip_mtp: + block_count = self.hparams["num_hidden_layers"] + match = re.match(r"model.layers.(\d+)", name) + if match and int(match.group(1)) >= block_count: + return # process the experts separately if name.find("mlp.experts") != -1: @@ -8001,10 +8100,6 @@ class MiniMaxM2Model(TextModel): model_arch = gguf.MODEL_ARCH.MINIMAXM2 _experts_cache: dict[int, dict[str, Tensor]] = {} - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.hparams["num_experts"] = self.hparams["num_local_experts"] - def set_gguf_parameters(self): super().set_gguf_parameters() @@ -8017,7 +8112,7 @@ class MiniMaxM2Model(TextModel): # merge expert weights if 'experts' in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None expert_cache = self._experts_cache.setdefault(bid, {}) @@ -8774,24 +8869,7 @@ class Glm4MoeModel(TextModel): self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) def set_vocab(self): - from transformers import AutoTokenizer - - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) - - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 - - special_vocab.add_to_gguf(self.gguf_writer) + return self._set_vocab_glm() def 
set_gguf_parameters(self): super().set_gguf_parameters() @@ -8891,26 +8969,38 @@ class Glm4MoeModel(TextModel): class Glm4MoeLiteModel(DeepseekV2Model): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 - # copied from Glm4MoeModel def set_vocab(self): - from transformers import AutoTokenizer + return self._set_vocab_glm() - tokenizer = AutoTokenizer.from_pretrained(self.dir_model) - special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) - tokens, toktypes, tokpre = self.get_vocab_base() - self.gguf_writer.add_tokenizer_model("gpt2") - self.gguf_writer.add_tokenizer_pre(tokpre) - self.gguf_writer.add_token_list(tokens) - self.gguf_writer.add_token_types(toktypes) - # Special tokens - # Note: Using <|endoftext|> (151329) for eot causes endless generation - special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["[gMASK]"]) # 151331 - special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|user|>"]) # 151336 - special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<|endoftext|>"]) # 151329 - special_vocab._set_special_token("eom", tokenizer.get_added_vocab()["<|observation|>"]) # 151338 +@ModelBase.register("GlmMoeDsaForCausalLM") +class GlmMoeDsaModel(DeepseekV2Model): + model_arch = gguf.MODEL_ARCH.GLM_DSA + skip_mtp = False - special_vocab.add_to_gguf(self.gguf_writer) + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.block_count = self.hparams["num_hidden_layers"] + self.hparams.get("num_nextn_predict_layers", 0) + self.tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count) + + def set_vocab(self): + return self._set_vocab_glm() + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + rope_dim = self.hparams["qk_rope_head_dim"] + partial_rotary_factor = self.hparams.get("partial_rotary_factor", 1.0) + self.gguf_writer.add_rope_dimension_count(int(rope_dim * partial_rotary_factor)) + + # NextN/MTP prediction layers + if (num_nextn_predict_layers := self.hparams.get("num_nextn_predict_layers")) is not None: + self.gguf_writer.add_nextn_predict_layers(num_nextn_predict_layers) + + # DSA indexer parameters + self.gguf_writer.add_indexer_head_count(self.hparams["index_n_heads"]) + self.gguf_writer.add_indexer_key_length(self.hparams["index_head_dim"]) + self.gguf_writer.add_indexer_top_k(self.hparams["index_topk"]) @ModelBase.register("GlmForCausalLM", "ChatGLMModel", "ChatGLMForConditionalGeneration") @@ -9227,7 +9317,6 @@ class ExaoneMoEModel(Exaone4Model): def set_gguf_parameters(self): super().set_gguf_parameters() - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) moe_intermediate_size = self.hparams["moe_intermediate_size"] num_shared_experts = self.hparams["num_shared_experts"] self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) @@ -9268,7 +9357,7 @@ class ExaoneMoEModel(Exaone4Model): name = name.replace("e_score_correction_bias", "e_score_correction.bias") if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9419,7 +9508,7 @@ class GraniteHybridModel(Mamba2Model, GraniteMoeModel): # case, the model architecture needs to be updated to a standard # "granite" or "granitemoe" model if not self._ssm_layers: - has_experts = self.find_hparam(["num_experts_per_tok"], optional=True) + has_experts = self.find_hparam(["num_experts_per_tok", "num_experts_per_token"], optional=True) new_arch = ( 
gguf.MODEL_ARCH.GRANITE_MOE if has_experts else @@ -9615,6 +9704,14 @@ class NemotronHModel(GraniteHybridModel): self.gguf_writer.add_add_bos_token(True) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + # Skip vision model and projector tensors for VLM models (handled by mmproj) (e.g., Nemotron Nano 12B v2 VL) + if name.startswith(("vision_model.", "mlp1.")): + return + + # Strip language_model. prefix for VLM models (e.g., Nemotron Nano 12B v2 VL) + if name.startswith("language_model."): + name = name[len("language_model."):] + if self.is_moe and bid is not None: if name.endswith("mixer.gate.e_score_correction_bias"): new_name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -9709,7 +9806,6 @@ class BailingMoeModel(TextModel): self.gguf_writer.add_vocab_size(hparams["vocab_size"]) self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_weights_scale(1.0) - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) @@ -9743,7 +9839,7 @@ class BailingMoeModel(TextModel): yield from super().modify_tensors(v,self.format_tensor_name(gguf.MODEL_TENSOR.ATTN_V, bid), bid) return elif name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9814,7 +9910,6 @@ class BailingMoeV2Model(TextModel): self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) self.gguf_writer.add_expert_shared_feed_forward_length(hparams.get("moe_shared_expert_intermediate_size", hparams["moe_intermediate_size"] * hparams["num_shared_experts"])) self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_count(hparams["num_shared_experts"]) self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) @@ -9825,7 +9920,7 @@ class BailingMoeV2Model(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: if "mlp.experts" in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -9871,8 +9966,6 @@ class GroveMoeModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) if (moe_intermediate_size := self.hparams.get("moe_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) logger.info(f"gguf: expert feed forward length = {moe_intermediate_size}") @@ -9893,7 +9986,7 @@ class GroveMoeModel(TextModel): # process the experts separately if name.find("chunk_experts") != -1: - n_experts = self.hparams["num_experts"] // 2 # see add_experts_per_group + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) // 2 # see add_experts_per_group assert bid is not None if self._chunk_experts is None: @@ -9920,7 +10013,7 @@ class GroveMoeModel(TextModel): else: return elif name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if 
self._experts is None: @@ -10313,7 +10406,6 @@ class HunYuanMoEModel(TextModel): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_expert_count(hparams["num_experts"]) self.gguf_writer.add_expert_shared_feed_forward_length(hparams["intermediate_size"]) moe_intermediate_size = hparams["moe_intermediate_size"] @@ -10356,7 +10448,7 @@ class HunYuanMoEModel(TextModel): return if name.find("mlp.experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10398,16 +10490,9 @@ class LLaDAMoEModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts")) is not None: - self.gguf_writer.add_expert_count(n_experts) - if (expert_intermediate_size := self.hparams.get("expert_intermediate_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(expert_intermediate_size) - # number of experts used per token (top-k) - if (n_experts_used := self.hparams.get("num_experts_per_tok")) is not None: - self.gguf_writer.add_expert_used_count(n_experts_used) - self.gguf_writer.add_mask_token_id(156895) self.gguf_writer.add_causal_attention(False) self.gguf_writer.add_diffusion_shift_logits(False) @@ -10418,7 +10503,7 @@ class LLaDAMoEModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: @@ -10755,7 +10840,6 @@ class LFM2MoeModel(TextModel): super().set_gguf_parameters() - self.gguf_writer.add_expert_count(self.hparams["num_experts"]) self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) self.gguf_writer.add_leading_dense_block_count(self.hparams["num_dense_layers"]) self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID) @@ -10776,7 +10860,7 @@ class LFM2MoeModel(TextModel): # merge expert weights if 'experts' in name: - n_experts = self.hparams["num_experts"] + n_experts = self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None expert_cache = self._experts_cache.setdefault(bid, {}) @@ -10886,9 +10970,9 @@ class SmallThinkerModel(TextModel): def set_gguf_parameters(self): super().set_gguf_parameters() - if (n_experts := self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts"))) is not None: + if (n_experts := self.hparams.get("moe_num_primary_experts")) is not None: self.gguf_writer.add_expert_count(n_experts) - if (n_experts_used := self.hparams.get("num_experts_per_tok", self.hparams.get("moe_num_active_primary_experts"))) is not None: + if (n_experts_used := self.hparams.get("moe_num_active_primary_experts")) is not None: self.gguf_writer.add_expert_used_count(n_experts_used) if (moe_intermediate_size := self.hparams.get("moe_ffn_hidden_size")) is not None: self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) @@ -10913,7 +10997,7 @@ class SmallThinkerModel(TextModel): def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: # process the experts separately if name.find("experts") != -1: - n_experts = self.hparams.get("num_experts", self.hparams.get("moe_num_primary_experts")) + n_experts = self.hparams.get("moe_num_primary_experts") or 
self.find_hparam(["num_local_experts", "num_experts"]) assert bid is not None if self._experts is None: diff --git a/docs/backend/snapdragon/README.md b/docs/backend/snapdragon/README.md index 8e1f37b206..2c3f88e91a 100644 --- a/docs/backend/snapdragon/README.md +++ b/docs/backend/snapdragon/README.md @@ -35,7 +35,7 @@ Adapt below build commands accordingly. Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: ``` -[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json . +[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json . [d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon Preset CMake variables: diff --git a/docs/build-s390x.md b/docs/build-s390x.md index 67df4e2eac..4568d5010f 100644 --- a/docs/build-s390x.md +++ b/docs/build-s390x.md @@ -242,10 +242,10 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl |------------|-------------|------|-------| | FP32 | ✅ | ✅ | ❓ | | FP16 | ✅ | ✅ | ❓ | -| BF16 | 🚫 | ✅ | ❓ | +| BF16 | ✅ | ✅ | ❓ | | Q4_0 | ✅ | ❓ | ❓ | | Q4_1 | ✅ | ❓ | ❓ | -| MXFP4 | 🚫 | ❓ | ❓ | +| MXFP4 | ✅ | ❓ | ❓ | | Q5_0 | ✅ | ❓ | ❓ | | Q5_1 | ✅ | ❓ | ❓ | | Q8_0 | ✅ | ❓ | ❓ | @@ -272,4 +272,4 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl - 🚫 - acceleration unavailable, will still run using scalar implementation - ❓ - acceleration unknown, please contribute if you can test it yourself -Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 7, 2025. +Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Feb 15, 2026. diff --git a/ggml/CMakeLists.txt b/ggml/CMakeLists.txt index 71d1a7f0e3..4323afe57b 100644 --- a/ggml/CMakeLists.txt +++ b/ggml/CMakeLists.txt @@ -4,7 +4,7 @@ project("ggml" C CXX ASM) ### GGML Version set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MINOR 9) -set(GGML_VERSION_PATCH 5) +set(GGML_VERSION_PATCH 7) set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) diff --git a/ggml/src/ggml-cpu/CMakeLists.txt b/ggml/src/ggml-cpu/CMakeLists.txt index 7622d0bf49..43d6f7f54f 100644 --- a/ggml/src/ggml-cpu/CMakeLists.txt +++ b/ggml/src/ggml-cpu/CMakeLists.txt @@ -569,27 +569,24 @@ function(ggml_add_cpu_backend_variant_impl tag_name) cmake_policy(SET CMP0135 NEW) endif() + # TODO: Use FetchContent_MakeAvailable with EXCLUDE_FROM_ALL after bumping minimum CMake version to 3.28+ + # Using FetchContent_Populate instead to avoid EXCLUDE_FROM_ALL which requires CMake 3.28 FetchContent_Declare(KleidiAI_Download URL ${KLEIDIAI_DOWNLOAD_URL} DOWNLOAD_EXTRACT_TIMESTAMP NEW URL_HASH MD5=${KLEIDIAI_ARCHIVE_MD5}) - FetchContent_MakeAvailable(KleidiAI_Download) FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC POPULATED KLEIDIAI_POPULATED) if (NOT KLEIDIAI_POPULATED) - message(FATAL_ERROR "KleidiAI source downloaded failed.") + FetchContent_Populate(KleidiAI_Download) + FetchContent_GetProperties(KleidiAI_Download SOURCE_DIR KLEIDIAI_SRC) endif() add_compile_definitions(GGML_USE_CPU_KLEIDIAI) - # Remove kleidiai target after fetching it - if (TARGET kleidiai) - set_target_properties(kleidiai PROPERTIES EXCLUDE_FROM_ALL TRUE) - endif() - list(APPEND GGML_CPU_SOURCES ggml-cpu/kleidiai/kleidiai.cpp ggml-cpu/kleidiai/kernels.cpp diff --git a/ggml/src/ggml-cpu/arch/arm/repack.cpp b/ggml/src/ggml-cpu/arch/arm/repack.cpp index fd05c609f7..3a3b32efb2 100644 --- a/ggml/src/ggml-cpu/arch/arm/repack.cpp +++ 
b/ggml/src/ggml-cpu/arch/arm/repack.cpp @@ -3226,6 +3226,316 @@ void ggml_gemm_q4_K_8x8_q8_K(int n, UNUSED(ncols_interleaved); UNUSED(blocklen); +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) && defined(__ARM_FEATURE_MATMUL_INT8) + if (svcntb() * 8 == 256) { + constexpr int q8_k_blocklen = 4; + const svuint8_t m4b_1 = svdup_n_u8(0x0f); + // 8 accumulators: 2 row pairs × 4 col pairs + svfloat32_t acc_f32_01, acc_f32_23, acc_f32_45, acc_f32_67; + uint32_t idx_arr[8] = { 0, 2, 4, 6, 1, 3, 5, 7 }; + svbool_t pg = svptrue_pat_b32(SV_VL8); + svuint32_t idx = svld1(pg, idx_arr); + + static const uint32_t idx_data[8] = {0, 4, 2, 6, 1, 5, 3, 7}; + svuint32_t idx1 = svld1_u32(svptrue_b32(), idx_data); + + for (int y = 0; y < nr / q8_k_blocklen; y++) { + const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb); + + for (int x = 0; x < nc / ncols_interleaved; x++) { + const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb); + + acc_f32_01 = svdup_n_f32(0); + acc_f32_23 = svdup_n_f32(0); + acc_f32_45 = svdup_n_f32(0); + acc_f32_67 = svdup_n_f32(0); + + for (int b = 0; b < nb; b++) { + // bsums pairs belongs to the same q8_k subblock + // 64 elemnts loaded and made sum of 0-7 and 8-15 sum || 16-23 and 24 - 31 sum + const int16x8_t bsums[4]{ + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)), + vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)), + }; + + int32_t bsums_arr32[4][8]; + + for (int q8_row = 0; q8_row < 4; q8_row++) { + int16x8_t v16 = bsums[q8_row]; + + // low 4 + int32x4_t v32_lo = vmovl_s16(vget_low_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][0], v32_lo); + + // high 4 + int32x4_t v32_hi = vmovl_s16(vget_high_s16(v16)); + vst1q_s32(&bsums_arr32[q8_row][4], v32_hi); + } + + svint32_t sb_acc_0 = svdup_n_s32(0); + svint32_t sb_acc_2 = svdup_n_s32(0); + + svint32_t acc_00 = svdup_n_s32(0); + svint32_t acc_11 = svdup_n_s32(0); + svint32_t acc_22 = svdup_n_s32(0); + svint32_t acc_33 = svdup_n_s32(0); + svint32_t acc_44 = svdup_n_s32(0); + svint32_t acc_55 = svdup_n_s32(0); + svint32_t acc_66 = svdup_n_s32(0); + svint32_t acc_77 = svdup_n_s32(0); + + svint32_t bias_acc_00 = svdup_n_s32(0); + svint32_t bias_acc_22 = svdup_n_s32(0); + svint32_t bias_acc_44 = svdup_n_s32(0); + svint32_t bias_acc_66 = svdup_n_s32(0); + + for (int sb = 0; sb < QK_K / 64; sb++) { + // Need scales for the low and high nibbles + // 2 * 12 = 24 bytes per subblock, 4 sbs -> 4 * 24 = 96 bytes total + svint32_t block_scale_0, block_scale_1, block_scale_2, block_scale_3; + svint32_t q4sb_mins_0, q4sb_mins_1; + { + // 2-superblock I am working on + const int offset = sb * 24 + 0 * 12; + const uint8_t * scales_in = &q4_ptr[b].scales[offset]; + + const int offset1 = sb * 24 + 12; + const uint8_t * scales_in1 = &q4_ptr[b].scales[offset1]; + + constexpr uint32_t kmask1 = 0x3f3f3f3f; + constexpr uint32_t kmask2 = 0x0f0f0f0f; + constexpr uint32_t kmask3 = 0x03030303; + constexpr uint8_t scales_size = 12; + + uint32_t sm[3]; + memcpy(sm, scales_in, scales_size); + + uint32_t sm1[3]; + memcpy(sm1, scales_in1, scales_size); + + const uint32_t mins_0_3 = sm[1] & kmask1; + const uint32_t mins_4_7 = ((sm[2] >> 4) & kmask2) | (((sm[1] >> 6) & kmask3) << 4); + + const uint32_t mins_0_3_1 = sm1[1] & kmask1; + const uint32_t 
mins_4_7_1 = ((sm1[2] >> 4) & kmask2) | (((sm1[1] >> 6) & kmask3) << 4); + + svuint32_t mins_u32_temp = svzip1_u32(svdup_n_u32(mins_0_3), svdup_n_u32(mins_4_7)); + svuint32_t mins_u32_temp_1 = svzip1_u32(svdup_n_u32(mins_0_3_1), svdup_n_u32(mins_4_7_1)); + + /* reinterpret u32 → u8 */ + svuint8_t mins_u8 = svreinterpret_u8_u32(mins_u32_temp); + svuint8_t mins_u8_1 = svreinterpret_u8_u32(mins_u32_temp_1); + + /* widen u8 → u16->u32 (lower half only) */ + svuint32_t mins_u16 = svunpklo_u32(svunpklo_u16(mins_u8)); + svuint32_t mins_u16_1 = svunpklo_u32(svunpklo_u16(mins_u8_1)); + + q4sb_mins_0 = svreinterpret_s32_u32(mins_u16); + q4sb_mins_1 = svreinterpret_s32_u32(mins_u16_1); + + uint32_t scales_u32_0 = sm[0] & kmask1; + uint32_t scales_u32_1 = (sm[2] & kmask2) | (((sm[0] >> 6) & kmask3) << 4); + uint32_t scales_u32_2 = sm1[0] & kmask1; + uint32_t scales_u32_3 = (sm1[2] & kmask2) | (((sm1[0] >> 6) & kmask3) << 4); + + svuint32_t S01 = svdup_n_u32(scales_u32_0); + svuint32_t S23 = svdup_n_u32(scales_u32_1); + svuint32_t R01 = svdup_n_u32(scales_u32_2); + svuint32_t R23 = svdup_n_u32(scales_u32_3); + + svint8_t S01_b = svreinterpret_s8_u32(S01); + svint8_t S23_b = svreinterpret_s8_u32(S23); + svint8_t R01_b = svreinterpret_s8_u32(R01); + svint8_t R23_b = svreinterpret_s8_u32(R23); + + svint32_t S01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S01_b, S01_b))); + svint32_t R01_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R01_b, R01_b))); + svint32_t S23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(S23_b, S23_b))); + svint32_t R23_d = svunpklo_s32(svunpklo_s16(svzip1_s8(R23_b, R23_b))); + + block_scale_0 = svtbl_s32(svzip1_s32(S01_d, R01_d), idx); + block_scale_1 = svtbl_s32(svzip2_s32(S01_d, R01_d), idx); + block_scale_2 = svtbl_s32(svzip1_s32(S23_d, R23_d), idx); + block_scale_3 = svtbl_s32(svzip2_s32(S23_d, R23_d), idx); + } + + const int8_t * q8_base_1 = q8_ptr[b].qs + sb * 256; + + // Load 32-byte per row pair, 1 subblock each time + // predicate for activating higher lanes for 16 int8 elements + const svbool_t ph16 = svptrue_pat_b8(SV_VL16); + // predicate for activating lower lanes for 16 int8 elements + const svbool_t pl16 = svnot_b_z(svptrue_b8(), ph16); + + svint8_t q8_qs_0 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 0), svld1_s8(pl16, q8_base_1 + 112)); + svint8_t q8_qs_2 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 32), svld1_s8(pl16, q8_base_1 + 144)); + svint8_t q8_qs_4 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 64), svld1_s8(pl16, q8_base_1 + 176)); + svint8_t q8_qs_6 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 96), svld1_s8(pl16, q8_base_1 + 208)); + + svint8_t q8_qs_1 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 16), svld1_s8(pl16, q8_base_1 + 128)); + svint8_t q8_qs_3 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 48), svld1_s8(pl16, q8_base_1 + 160)); + svint8_t q8_qs_5 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 80), svld1_s8(pl16, q8_base_1 + 192)); + svint8_t q8_qs_7 = svadd_s8_x(svptrue_b8(), svld1_s8(ph16, q8_base_1 + 112), svld1_s8(pl16, q8_base_1 + 224)); + + // Q4s columns iterated in pairs (01, 23, 45, 67) + for (int cp = 0; cp < ncols_interleaved / 2; cp++) { + + sb_acc_0 = svdup_n_s32(0); + sb_acc_2 = svdup_n_s32(0); + + svuint8_t q4_qs_cp_00 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 0); + svuint8_t q4_qs_cp_01 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 64); + svuint8_t q4_qs_cp_02 = svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 128); + svuint8_t q4_qs_cp_03 
= svld1rq_u8(svptrue_b8(), q4_ptr[b].qs + sb * QK_K + 16 * cp + 192); + + svint8_t q4_nibbles_00 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_00, m4b_1), 4)); + svint8_t q4_nibbles_01 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_01, m4b_1), 4)); + svint8_t q4_nibbles_02 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_02, m4b_1), 4)); + svint8_t q4_nibbles_03 = svreinterpret_s8_u8(svlsr_n_u8_m(pl16, svand_u8_m(ph16, q4_qs_cp_03, m4b_1), 4)); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_00, q8_qs_0); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_01, q8_qs_2); + + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_02, q8_qs_4); + sb_acc_0 = svmmla_s32(sb_acc_0, q4_nibbles_03, q8_qs_6); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_00, q8_qs_1); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_01, q8_qs_3); + + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_02, q8_qs_5); + sb_acc_2 = svmmla_s32(sb_acc_2, q4_nibbles_03, q8_qs_7); + + if(cp == 0) { + acc_00 = svmla_s32_m(svptrue_b32(), acc_00, sb_acc_0, block_scale_0); + acc_44 = svmla_s32_m(svptrue_b32(), acc_44, sb_acc_2, block_scale_0); + } + if(cp == 1) { + acc_11 = svmla_s32_m(svptrue_b32(), acc_11, sb_acc_0, block_scale_1); + acc_55 = svmla_s32_m(svptrue_b32(), acc_55, sb_acc_2, block_scale_1); + } + if(cp == 2) { + acc_22 = svmla_s32_m(svptrue_b32(), acc_22, sb_acc_0, block_scale_2); + acc_66 = svmla_s32_m(svptrue_b32(), acc_66, sb_acc_2, block_scale_2); + } + if(cp == 3) { + acc_33 = svmla_s32_m(svptrue_b32(), acc_33, sb_acc_0, block_scale_3); + acc_77 = svmla_s32_m(svptrue_b32(), acc_77, sb_acc_2, block_scale_3); + } + } + + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][0]), q4sb_mins_0); + bias_acc_00 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_00, svdup_n_s32(bsums_arr32[sb][1]), q4sb_mins_1); + + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][2]), q4sb_mins_0); + bias_acc_22 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_22, svdup_n_s32(bsums_arr32[sb][3]), q4sb_mins_1); + + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][4]), q4sb_mins_0); + bias_acc_44 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_44, svdup_n_s32(bsums_arr32[sb][5]), q4sb_mins_1); + + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][6]), q4sb_mins_0); + bias_acc_66 = svmla_s32_m(svptrue_pat_b32(SV_VL8), bias_acc_66, svdup_n_s32(bsums_arr32[sb][7]), q4sb_mins_1); + } // for sb + + + acc_00 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_00, svext_s32(acc_00, acc_00, 4)); + acc_11 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_11, svext_s32(acc_11, acc_11, 4)); + acc_22 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_22, svext_s32(acc_22, acc_22, 4)); + acc_33 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_33, svext_s32(acc_33, acc_33, 4)); + acc_44 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_44, svext_s32(acc_44, acc_44, 4)); + acc_55 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_55, svext_s32(acc_55, acc_55, 4)); + acc_66 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_66, svext_s32(acc_66, acc_66, 4)); + acc_77 = svadd_s32_z(svptrue_pat_b32(SV_VL4), acc_77, svext_s32(acc_77, acc_77, 4)); + + svint32_t reorder_acc_01 = svtbl_s32( svzip1_s32( svtrn1_s32(acc_00, acc_11), svtrn1_s32(acc_22, acc_33)), idx1); + svint32_t reorder_acc_23 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_00, acc_11), svtrn2_s32(acc_22, acc_33)), idx1); + + svint32_t reorder_acc_45 = svtbl_s32( 
svzip1_s32( svtrn1_s32(acc_44, acc_55), svtrn1_s32(acc_66, acc_77)), idx1); + svint32_t reorder_acc_67 = svtbl_s32( svzip1_s32( svtrn2_s32(acc_44, acc_55), svtrn2_s32(acc_66, acc_77)), idx1); + + // Broadcast q8 scalar + svfloat32_t q8_d = svdup_f32(q8_ptr[b].d[0]); + + svfloat32_t q4_dmin_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].dmin), svdup_f16(0))); + + svfloat32_t q4_d_temp = svcvt_f32_f16_x(svptrue_b32(), svzip1_f16( svld1_f16(svptrue_pat_b16(SV_VL8), (const __fp16 *)q4_ptr[b].d), svdup_f16(0))); + + svfloat32_t scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + svfloat32_t dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_01 = svmls_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_00), dmins1); + acc_f32_01 = svmla_f32_m(svptrue_b32(), acc_f32_01, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_01), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[1]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_23 = svmls_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_22), dmins1); + acc_f32_23 = svmla_f32_m(svptrue_b32(), acc_f32_23, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_23), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[2]); + + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_45 = svmls_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_44), dmins1); + acc_f32_45 = svmla_f32_m(svptrue_b32(), acc_f32_45, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_45), scale1); + + q8_d = svdup_f32(q8_ptr[b].d[3]); + + scale1 = svmul_f32_x(svptrue_b32(), q4_d_temp, q8_d); + dmins1 = svmul_f32_x(svptrue_b32(), q4_dmin_temp, q8_d); + + acc_f32_67 = svmls_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), bias_acc_66), dmins1); + acc_f32_67 = svmla_f32_m(svptrue_b32(), acc_f32_67, svcvt_f32_s32_m(svdup_n_f32(0), svptrue_b32(), reorder_acc_67), scale1); + + } // for b + + // With the previous reorder, the tile is already in the correct memory layout. 
+ // Predicate for exactly 4 lanes + svbool_t pg4 = svptrue_pat_b32(SV_VL4); + for (int i = 0; i < q8_k_blocklen; i++) { + int row = y * q8_k_blocklen + i; + for (int j = 0; j < 2; j++) { + int col = x * ncols_interleaved + j * 4; + int offset = row * bs + col; + + if (i == 0 && j == 0) { + // acc_f32_0 → lower half of acc_f32_01 + svst1_f32(pg4, s + offset, acc_f32_01); + } else if (i == 0 && j == 1) { + // acc_f32_1 → upper half of acc_f32_01 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_01, acc_f32_01, 4)); + } else if (i == 1 && j == 0) { + // acc_f32_2 + svst1_f32(pg4, s + offset, acc_f32_23); + } else if (i == 1 && j == 1) { + // acc_f32_3 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_23, acc_f32_23, 4)); + } else if (i == 2 && j == 0) { + // acc_f32_4 + svst1_f32(pg4, s + offset, acc_f32_45); + } else if (i == 2 && j == 1) { + // acc_f32_5 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_45, acc_f32_45, 4)); + } else if (i == 3 && j == 0) { + // acc_f32_6 + svst1_f32(pg4, s + offset, acc_f32_67); + } else if (i == 3 && j == 1) { + // acc_f32_7 + svst1_f32(pg4, s + offset, svext_f32(acc_f32_67, acc_f32_67, 4)); + } + } + } + } // for x + } // for y + return; + } +#endif // SVE compile-time end + #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_MATMUL_INT8) constexpr int q8_k_blocklen = 4; const uint8x16_t m4b = vdupq_n_u8(0x0f); diff --git a/ggml/src/ggml-cpu/common.h b/ggml/src/ggml-cpu/common.h index 1057b5bb15..abbadc359c 100644 --- a/ggml/src/ggml-cpu/common.h +++ b/ggml/src/ggml-cpu/common.h @@ -6,8 +6,8 @@ #include "ggml-impl.h" #include "simd-mappings.h" -#define GGML_FA_TILE_Q 32 -#define GGML_FA_TILE_KV 16 +#define GGML_FA_TILE_Q 64 +#define GGML_FA_TILE_KV 64 #ifdef __cplusplus diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c index b003fe13fd..64eb01a4e1 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -2874,8 +2874,8 @@ struct ggml_cplan ggml_graph_plan( const int64_t DV = node->src[2]->ne[0]; // Tiled flash attention scratch (tile sizes defined in common.h) - // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding - size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; + // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + K_f32 + padding + size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV + GGML_FA_TILE_KV*DK)*n_tasks; // Decode path: n_kv_chunks = n_tasks (one chunk per thread) // Per-thread: VKQ accmulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ @@ -2947,7 +2947,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { /*.use_ref =*/ cplan->use_ref, }; - GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif for (int node_n = 0; node_n < cgraph->n_nodes && atomic_load_explicit(&tp->abort, memory_order_relaxed) != node_n; node_n++) { struct ggml_tensor * node = cgraph->nodes[node_n]; @@ -2974,7 +2978,11 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { } } - GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); +#ifdef GGML_USE_OPENMP + 
GGML_PRINT_DEBUG("thread #%d compute-done cplan %p\n", state->ith, (const void *)cplan); +#else + GGML_PRINT_DEBUG("thread #%d compute-done cplan %p last-graph %d\n", state->ith, (const void *)cplan, state->last_graph); +#endif ggml_barrier(state->threadpool); diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index ed45350207..b7a70e06f1 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -3,6 +3,7 @@ #include "ggml-cpu.h" #include "ggml-impl.h" #include "binary-ops.h" +#include "simd-gemm.h" #include "ggml.h" #include "unary-ops.h" #include "vec.h" @@ -2096,10 +2097,14 @@ static void ggml_compute_forward_gelu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2113,10 +2118,14 @@ static void ggml_compute_forward_gelu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2135,10 +2144,14 @@ static void ggml_compute_forward_gelu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2152,10 +2165,14 @@ static void ggml_compute_forward_gelu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2276,10 +2293,14 @@ static void ggml_compute_forward_gelu_erf_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2293,10 +2314,14 @@ static void ggml_compute_forward_gelu_erf_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for 
(int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2315,10 +2340,14 @@ static void ggml_compute_forward_gelu_erf_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2332,10 +2361,14 @@ static void ggml_compute_forward_gelu_erf_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_erf_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2379,10 +2412,14 @@ static void ggml_compute_forward_gelu_quick_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2396,10 +2433,14 @@ static void ggml_compute_forward_gelu_quick_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2418,10 +2459,14 @@ static void ggml_compute_forward_gelu_quick_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2435,10 +2480,14 @@ static void ggml_compute_forward_gelu_quick_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - 
i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_gelu_quick_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2482,10 +2531,14 @@ static void ggml_compute_forward_silu_f32( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2499,10 +2552,14 @@ static void ggml_compute_forward_silu_f32( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f32(nc, - (float *) ((char *) dst->data + i1*( dst->nb[1])), - (float *) ((char *) src0->data + i1*(src0->nb[1]))); + (float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -2521,10 +2578,14 @@ static void ggml_compute_forward_silu_f16( const ggml_tensor * src0 = dst->src[0]; - assert(ggml_is_contiguous_1(src0)); - assert(ggml_is_contiguous_1(dst)); + assert(ggml_is_contiguous_rows(src0)); assert(ggml_are_same_shape(src0, dst)); + GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne) + GGML_TENSOR_LOCALS(size_t, nb0, src0, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + const int ith = params->ith; const int nth = params->nth; @@ -2538,10 +2599,14 @@ static void ggml_compute_forward_silu_f16( const int ir0 = dr*ith; const int ir1 = MIN(ir0 + dr, nr); - for (int i1 = ir0; i1 < ir1; i1++) { + for (int ir = ir0; ir < ir1; ++ir) { + const int i3 = ir/(ne02*ne01); + const int i2 = (ir - i3*ne02*ne01)/ne01; + const int i1 = (ir - i3*ne02*ne01 - i2*ne01); + ggml_vec_silu_f16(nc, - (ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])), - (ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1]))); + (ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1), + (ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01)); #ifndef NDEBUG for (int k = 0; k < nc; k++) { @@ -8325,10 +8390,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( GGML_ASSERT(k->type == v->type); const ggml_type kv_type = k->type; - const auto * kv_type_traits_cpu = ggml_get_type_traits_cpu(kv_type); - const ggml_from_float_t kv_from_float = kv_type_traits_cpu->from_float; - const ggml_vec_dot_t kv_vec_dot = kv_type_traits_cpu->vec_dot; - const size_t kv_type_size = ggml_type_size(kv_type); // broadcast factors const int64_t rk2 = neq2/nek2; @@ -8360,8 +8421,6 @@ static void ggml_compute_forward_flash_attn_ext_tiled( static constexpr int Q_TILE_SZ = ggml_fa_tile_config::Q; static constexpr int KV_TILE_SZ = ggml_fa_tile_config::KV; - GGML_ASSERT(nek1 % KV_TILE_SZ == 0 && "KV sequence length must be divisible by KV_TILE_SZ"); - int ir = ir0; while (ir < ir1) { // q indices for the start of this tile @@ -8388,18 +8447,20 @@ static void 
ggml_compute_forward_flash_attn_ext_tiled( } // Per-thread scratch layout: - // Q_q: Q_TILE_SZ * DK (converted Q tile in KV type) + // Q_q: Q_TILE_SZ * DK (converted Q tile — F32 for GEMM, KV type for scalar) // KQ: Q_TILE_SZ * KV_TILE_SZ (attention scores in float) // mask: Q_TILE_SZ * KV_TILE_SZ (mask in float) // VKQ32: Q_TILE_SZ * DV (FP32 output accumulator) - // V32: KV_TILE_SZ * DV (F32 buffer for V tile - used for f166 conversion) - float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + CACHE_LINE_SIZE_F32); + // V32: KV_TILE_SZ * DV (F32 buffer for V tile) + // K_f32: KV_TILE_SZ * DK (F32 buffer for K tile — GEMM path) + float * base = (float *) params->wdata + ith*(Q_TILE_SZ*DK + 2*Q_TILE_SZ*KV_TILE_SZ + Q_TILE_SZ*DV + KV_TILE_SZ*DV + KV_TILE_SZ*DK + CACHE_LINE_SIZE_F32); void * Q_q = base; float * KQ = (float *)((char *)base + Q_TILE_SZ * DK * sizeof(float)); float * mask32 = KQ + Q_TILE_SZ * KV_TILE_SZ; float * VKQ32 = mask32 + Q_TILE_SZ * KV_TILE_SZ; - float * V32 = VKQ32 + Q_TILE_SZ * DV; // F32 buffer for V tile + float * V32 = VKQ32 + Q_TILE_SZ * DV; + float * K_f32 = V32 + KV_TILE_SZ * DV; memset(VKQ32, 0, Q_TILE_SZ * DV * sizeof(float)); memset(mask32, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); @@ -8412,28 +8473,38 @@ static void ggml_compute_forward_flash_attn_ext_tiled( const int iv3 = iq3 / rv3; const int iv2 = iq2 / rv2; - for (int tq = 0; tq < tile_rows; tq++) { - const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); - kv_from_float(pq, (char *)Q_q + tq * DK * kv_type_size, DK); - } - // Zero-pad remaining rows - for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { - memset((char *)Q_q + tq * DK * kv_type_size, 0, DK * kv_type_size); + { + float * Q_f32 = (float *)Q_q; + for (int tq = 0; tq < tile_rows; tq++) { + const float * pq = (const float *) ((char *) q->data + ((iq1 + tq)*nbq1 + iq2*nbq2 + iq3*nbq3)); + memcpy(Q_f32 + tq * DK, pq, DK * sizeof(float)); + } + for (int tq = tile_rows; tq < Q_TILE_SZ; tq++) { + memset(Q_f32 + tq * DK, 0, DK * sizeof(float)); + } } + memset(K_f32, 0, DK * KV_TILE_SZ * sizeof(float)); + memset(V32, 0, KV_TILE_SZ * DV * sizeof(float)); + for (int64_t ic = 0; ic < nek1; ic += KV_TILE_SZ) { + const int kv_tile = (int)std::min((int64_t)KV_TILE_SZ, nek1 - ic); // skip the tile entirely if all the masks are -inf if (mask) { bool can_skip = true; for (int tq = 0; tq < tile_rows; tq++) { const ggml_fp16_t * mp_row = (const ggml_fp16_t *)((const char *) mask->data + (iq1 + tq)*mask->nb[1] + (iq2%mask->ne[2])*mask->nb[2] + (iq3%mask->ne[3])*mask->nb[3]); - for (int tk = 0; tk < KV_TILE_SZ; tk++) { + for (int tk = 0; tk < kv_tile; tk++) { mask32[tq * KV_TILE_SZ + tk] = slope * GGML_CPU_FP16_TO_FP32(mp_row[ic + tk]); if (mask32[tq * KV_TILE_SZ + tk] != -INFINITY) { can_skip = false; } } + // Pad remaining mask entries with -inf + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + mask32[tq * KV_TILE_SZ + tk] = -INFINITY; + } } if (can_skip) { @@ -8441,13 +8512,32 @@ static void ggml_compute_forward_flash_attn_ext_tiled( } } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - const void * q_row = (const char *)Q_q + tq * DK * kv_type_size; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const void * k_row = (const char *) k->data + ((ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3); - float s; - kv_vec_dot(DK, &s, 0, k_row, 0, q_row, 0, 1); - KQ[tq * KV_TILE_SZ + tk] = s * scale; + // Pack K tile transposed: K_f32[dk][kv] so KV_TILE is contiguous (SIMD dim) + // 
Zero-pad the last tile so the GEMM always operates on KV_TILE_SZ columns + for (int tk = 0; tk < kv_tile; tk++) { + const char * k_data = (const char *)k->data + (ic + tk)*nbk1 + ik2*nbk2 + ik3*nbk3; + if (kv_type == GGML_TYPE_F16) { + const ggml_fp16_t * k_f16 = (const ggml_fp16_t *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = GGML_CPU_FP16_TO_FP32(k_f16[dk]); + } + } else { + const float * k_f32_src = (const float *)k_data; + for (int64_t dk = 0; dk < DK; dk++) { + K_f32[dk * KV_TILE_SZ + tk] = k_f32_src[dk]; + } + } + } + memset(KQ, 0, Q_TILE_SZ * KV_TILE_SZ * sizeof(float)); + simd_gemm(KQ, (const float *)Q_q, K_f32, Q_TILE_SZ, DK, KV_TILE_SZ); + ggml_vec_scale_f32(Q_TILE_SZ * KV_TILE_SZ, KQ, scale); + + // Set padded KQ entries to -inf so softmax gives them zero weight + if (kv_tile < KV_TILE_SZ) { + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + for (int tk = kv_tile; tk < KV_TILE_SZ; tk++) { + KQ[tq * KV_TILE_SZ + tk] = -INFINITY; + } } } @@ -8487,33 +8577,22 @@ static void ggml_compute_forward_flash_attn_ext_tiled( S[tq] += ggml_vec_soft_max_f32(KV_TILE_SZ, kq_row, kq_row, Mnew); } - // Convert V tile to F32 first (if F16), then do MAD - // On x86, ggml_vec_mad_f16 internall converts F16<->F32 on every load/store, so pre-converting is faster. - // TODO: on ARM, native f16 should be faster - if (kv_type == GGML_TYPE_F16) { - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const ggml_fp16_t * v_row = (const ggml_fp16_t *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_fp16_to_fp32_row(v_row, V32 + tk * DV, DV); - } - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - ggml_vec_mad_f32(DV, vkq_row, V32 + tk * DV, p); - } - } - } else { - for (int tq = 0; tq < Q_TILE_SZ; tq++) { - if (skip[tq]) continue; - float * vkq_row = VKQ32 + tq * DV; - for (int tk = 0; tk < KV_TILE_SZ; tk++) { - const float p = KQ[tq * KV_TILE_SZ + tk]; - const float * v_row = (const float *)((const char *) v->data + ((ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3)); - ggml_vec_mad_f32(DV, vkq_row, v_row, p); - } + // V accumulation: VKQ32 += softmax(KQ) * V + // Pack V tile to contiguous F32, zero-padded + for (int tk = 0; tk < kv_tile; tk++) { + const char * v_data = (const char *)v->data + (ic + tk)*nbv1 + iv2*nbv2 + iv3*nbv3; + if (kv_type == GGML_TYPE_F16) { + ggml_fp16_to_fp32_row((const ggml_fp16_t *)v_data, V32 + tk * DV, DV); + } else { + memcpy(V32 + tk * DV, v_data, DV * sizeof(float)); } } + for (int tq = 0; tq < Q_TILE_SZ; tq++) { + if (skip[tq]) { + memset(KQ + tq * KV_TILE_SZ, 0, KV_TILE_SZ * sizeof(float)); + } + } + simd_gemm(VKQ32, KQ, V32, Q_TILE_SZ, KV_TILE_SZ, DV); } // sinks (apply only to valid rows in the tile) @@ -8730,15 +8809,15 @@ static void ggml_compute_forward_flash_attn_ext_f16( const int64_t dr = (nr + nchunk - 1) / nchunk; - static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV; static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q; - const bool use_tiled = !use_ref && + bool use_tiled = !use_ref && (q->type == GGML_TYPE_F32 && kv_is_f32_or_f16 && k->type == v->type && - nek1 % KV_TILE_SZ == 0 && neq1 >= Q_TILE_SZ); - +#ifdef GGML_SIMD + use_tiled &= (DV % GGML_F32_EPR == 0); +#endif int current_chunk = ith; while (current_chunk < nchunk) { diff --git a/ggml/src/ggml-cpu/repack.cpp b/ggml/src/ggml-cpu/repack.cpp index 4cb7cdeb07..f94426ddd7 100644 --- a/ggml/src/ggml-cpu/repack.cpp 
+++ b/ggml/src/ggml-cpu/repack.cpp
@@ -1916,9 +1916,10 @@ static block_q4_Kx8 make_block_q4_Kx8(block_q4_K * in, unsigned int blck_size_in
         int src_offset = (i / 8) * blck_size_interleave;
         int dst_offset = i * blck_size_interleave;
 
+        // buffer large enough for the max interleave block size (8 bytes)
         uint64_t elems;
-        memcpy(&elems, &in[src_id].qs[src_offset], sizeof(uint64_t));
-        memcpy(&out.qs[dst_offset], &elems, sizeof(uint64_t));
+        memcpy(&elems, &in[src_id].qs[src_offset], blck_size_interleave);
+        memcpy(&out.qs[dst_offset], &elems, blck_size_interleave);
     }
 
     // The below logic is designed so as to unpack and rearrange scales and mins values in Q4_K
diff --git a/ggml/src/ggml-cpu/simd-gemm.h b/ggml/src/ggml-cpu/simd-gemm.h
new file mode 100644
index 0000000000..78d663e593
--- /dev/null
+++ b/ggml/src/ggml-cpu/simd-gemm.h
@@ -0,0 +1,136 @@
+#pragma once
+
+// Computes C[M x N] += A[M x K] * B[K x N]
+
+#include "simd-mappings.h"
+
+// TODO: add support for sizeless vector types
+#if defined(GGML_SIMD) && !defined(__ARM_FEATURE_SVE) && !defined(__riscv_v_intrinsic)
+
+// TODO: untested on avx512
+// These are in units of GGML_F32_EPR
+#if defined(__AVX512F__) || defined (__ARM_NEON__)
+    static constexpr int GEMM_RM = 4;
+    static constexpr int GEMM_RN = 4; // 16+4+1 = 25/32
+#elif defined(__AVX2__) || defined(__AVX__)
+    static constexpr int GEMM_RM = 6;
+    static constexpr int GEMM_RN = 2; // 12+2+1 = 15/16
+#else
+    static constexpr int GEMM_RM = 2;
+    static constexpr int GEMM_RN = 2;
+#endif
+
+template <int RM, int RN>
+static inline void simd_gemm_ukernel(
+    float * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    GGML_F32_VEC acc[RM][RN];
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            acc[i][r] = GGML_F32_VEC_LOAD(C + i * N + r * KN);
+        }
+    }
+
+    for (int64_t kk = 0; kk < K; kk++) {
+        GGML_F32_VEC Bv[RN];
+        for (int r = 0; r < RN; r++) {
+            Bv[r] = GGML_F32_VEC_LOAD(B + kk * N + r * KN);
+        }
+        for (int64_t i = 0; i < RM; i++) {
+            GGML_F32_VEC p = GGML_F32_VEC_SET1(A[i * K + kk]);
+            for (int r = 0; r < RN; r++) {
+                acc[i][r] = GGML_F32_VEC_FMA(acc[i][r], Bv[r], p);
+            }
+        }
+    }
+
+    for (int64_t i = 0; i < RM; i++) {
+        for (int r = 0; r < RN; r++) {
+            GGML_F32_VEC_STORE(C + i * N + r * KN, acc[i][r]);
+        }
+    }
+}
+
+// C[M x N] += A[M x K] * B[K x N]
+static void simd_gemm(
+    float * GGML_RESTRICT C,
+    const float * GGML_RESTRICT A,
+    const float * GGML_RESTRICT B,
+    int M, int K, int N)
+{
+    static constexpr int KN = GGML_F32_EPR;
+
+    int64_t ii = 0;
+    for (; ii + GEMM_RM <= M; ii += GEMM_RM) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<GEMM_RM, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<GEMM_RM, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            for (int64_t i = 0; i < GEMM_RM; i++) {
+                float a = C[i * N + jj];
+                for (int64_t kk = 0; kk < K; kk++) {
+                    a += A[i * K + kk] * B[kk * N + jj];
+                }
+                C[i * N + jj] = a;
+            }
+        }
+
+        A += GEMM_RM * K;
+        C += GEMM_RM * N;
+    }
+
+    // Tail rows: one at a time
+    for (; ii < M; ii++) {
+        int64_t jj = 0;
+        for (; jj + GEMM_RN * KN <= N; jj += GEMM_RN * KN) {
+            simd_gemm_ukernel<1, GEMM_RN>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj + KN <= N; jj += KN) {
+            simd_gemm_ukernel<1, 1>(C + jj, A, B + jj, K, N);
+        }
+        for (; jj < N; jj++) {
+            float a = C[jj];
+            for (int64_t kk = 0; kk < K; kk++) {
+                a += A[kk] * B[kk * N + jj];
+            }
+            C[jj] = a;
+        }
+
+        A += K;
+        C += N;
+    }
+}
+
+#if 
defined(__GNUC__) && !defined(__clang__) +#pragma GCC diagnostic pop +#endif + +#else // scalar path + +static void simd_gemm( + float * GGML_RESTRICT C, + const float * GGML_RESTRICT A, + const float * GGML_RESTRICT B, + int M, int K, int N) +{ + for (int64_t i = 0; i < M; i++) { + for (int64_t j = 0; j < N; j++) { + float sum = C[i * N + j]; + for (int64_t kk = 0; kk < K; kk++) { + sum += A[i * K + kk] * B[kk * N + j]; + } + C[i * N + j] = sum; + } + } +} + +#endif // GGML_SIMD diff --git a/ggml/src/ggml-cpu/simd-mappings.h b/ggml/src/ggml-cpu/simd-mappings.h index 630e506542..22de55700d 100644 --- a/ggml/src/ggml-cpu/simd-mappings.h +++ b/ggml/src/ggml-cpu/simd-mappings.h @@ -1160,6 +1160,14 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) { float32x4_t tmp = x[0] + vec_reve(x[0]); \ res = tmp[0] + tmp[1]; \ } +#define GGML_F32x4_REDUCE_4(res, s0, s1, s2, s3) \ +{ \ + float32x4_t v = vec_add(vec_add(s0, s1), \ + vec_add(s2, s3)); \ + v = vec_add(v, vec_sld(v, v, 8)); \ + v = vec_add(v, vec_sld(v, v, 4)); \ + res += (ggml_float)vec_extract(v, 0); \ +} #define GGML_F32_VEC GGML_F32x4 #define GGML_F32_VEC_ZERO GGML_F32x4_ZERO @@ -1209,6 +1217,24 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { #define GGML_F16_VEC_MUL GGML_F32x4_MUL #define GGML_F16_VEC_REDUCE GGML_F32x4_REDUCE +// BF16 s390x +#define GGML_BF16_STEP 16 +#define GGML_BF16_EPR 8 + +#define GGML_BF16x8 __vector unsigned short +#define GGML_BF16x8_ZERO vec_splats((unsigned short)0) +#define GGML_BF16x8_LOAD(p) vec_xl(0, (const unsigned short *)(p)) + +#define GGML_BF16_VEC GGML_BF16x8 +#define GGML_BF16_VEC_ZERO GGML_BF16x8_ZERO +#define GGML_BF16_VEC_LOAD GGML_BF16x8_LOAD +#define GGML_BF16_TO_F32_LO(v) ((float32x4_t) vec_mergel((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_TO_F32_HI(v) ((float32x4_t) vec_mergeh((v), GGML_BF16_VEC_ZERO)) +#define GGML_BF16_FMA_LO(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_LO(x), GGML_BF16_TO_F32_LO(y)) +#define GGML_BF16_FMA_HI(acc, x, y) \ + (acc) = GGML_F32x4_FMA((acc), GGML_BF16_TO_F32_HI(x), GGML_BF16_TO_F32_HI(y)) + #elif defined(__riscv_v_intrinsic) // compatible with vlen >= 128 diff --git a/ggml/src/ggml-cpu/unary-ops.cpp b/ggml/src/ggml-cpu/unary-ops.cpp index 1d9873ad0f..1d8344436f 100644 --- a/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ggml/src/ggml-cpu/unary-ops.cpp @@ -111,7 +111,7 @@ template static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) { const ggml_tensor * src0 = dst->src[0]; - GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst)); GGML_TENSOR_UNARY_OP_LOCALS diff --git a/ggml/src/ggml-cpu/vec.cpp b/ggml/src/ggml-cpu/vec.cpp index 8708cd4e92..d0e4001338 100644 --- a/ggml/src/ggml-cpu/vec.cpp +++ b/ggml/src/ggml-cpu/vec.cpp @@ -236,8 +236,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * vfloat32m1_t redsum = __riscv_vfredusum_vs_f32m4_f32m1(vsum0, __riscv_vfmv_v_f_f32m1(0.0f, 1), vl); sumf += __riscv_vfmv_f_s_f32m1_f32(redsum); -#endif -#if defined(__POWER9_VECTOR__) +#elif defined(__POWER9_VECTOR__) || defined(__VXE__) || defined(__VXE2__) const int np = (n & ~(GGML_BF16_STEP - 1)); if (np > 0) { GGML_F32_VEC sum[4] = {GGML_F32_VEC_ZERO}; diff --git a/ggml/src/ggml-cuda/convert.cu b/ggml/src/ggml-cuda/convert.cu index ba3d4eeb88..09b6d5db6a 100644 --- a/ggml/src/ggml-cuda/convert.cu +++ 
b/ggml/src/ggml-cuda/convert.cu @@ -7,7 +7,8 @@ template static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y, - const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x); @@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __ } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; - const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + const int64_t i02 = dm.y; + const int64_t i03 = dm.x; - const int64_t ib = ibx0 + i00/qk; // block index - const int64_t iqs = (i00%qk)/qr; // quant index - const int64_t iybs = i00 - i00%qk; // y block start index - const int64_t y_offset = qr == 1 ? 1 : qk/2; + const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01; - // dequantize - float2 v; - dequantize_kernel(vx, ib, iqs, v); + const int64_t ib = ibx0 + i00/qk; // block index + const int64_t iqs = (i00%qk)/qr; // quant index + const int64_t iybs = i00 - i00%qk; // y block start index + const int64_t y_offset = qr == 1 ? 1 : qk/2; - const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs; - y[iy0 + 0] = ggml_cuda_cast(v.x); - y[iy0 + y_offset] = ggml_cuda_cast(v.y); + // dequantize + float2 v; + dequantize_kernel(vx, ib, iqs, v); + + const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs; + y[iy0 + 0] = ggml_cuda_cast(v.x); + y[iy0 + y_offset] = ggml_cuda_cast(v.y); + } } template @@ -485,9 +490,11 @@ template static void dequantize_block_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535)); dequantize_block<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template @@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t template static __global__ void convert_unary( - const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02, + const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, + const int64_t ne0203, const uint3 ne02, const int64_t s01, const int64_t s02, const int64_t s03) { const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; @@ -621,23 +629,29 @@ static __global__ void convert_unary( } const int64_t i01 = blockIdx.y; - const int64_t i02 = blockIdx.z % ne02; - const int64_t i03 = blockIdx.z / ne02; const src_t * x = (const src_t *) vx; - const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; - const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00; - y[iy] = ggml_cuda_cast(x[ix]); + for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) { + const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02); + 
const int64_t i02 = dm.y; + const int64_t i03 = dm.x; + + const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00; + const int64_t iy = (i0203*ne01 + i01)*ne00 + i00; + y[iy] = ggml_cuda_cast(x[ix]); + } } template static void convert_unary_cuda(const void * vx, dst_t * y, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) { - const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03); + const int64_t ne0203 = ne02*ne03; + const uint3 ne02_fdv = init_fastdiv_values(ne02); + const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535)); convert_unary<<>> - (vx, y, ne00, ne01, ne02, s01, s02, s03); + (vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03); } template diff --git a/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 8694fd06c7..f19defbff9 100644 --- a/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16( constexpr int frag_m = ncols == 8 ? 32 : 16; constexpr int frag_n = ncols == 8 ? 8 : 16; static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0."); +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + typedef wmma::fragment frag_a_K; + typedef wmma::fragment frag_a_V; + typedef wmma::fragment frag_b; + typedef wmma::fragment frag_c_KQ; + typedef wmma::fragment frag_c_VKQ; +#else typedef wmma::fragment frag_a_K; typedef wmma::fragment frag_a_V; typedef wmma::fragment frag_b; typedef wmma::fragment frag_c_KQ; typedef wmma::fragment frag_c_VKQ; +#endif constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel. constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy. @@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16( __shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice. 
half2 * VKQ2 = (half2 *) VKQ; + +#if defined(GGML_USE_HIP) && HIP_VERSION >= 60500000 + const _Float16 * K_h_f16 = reinterpret_cast(K_h); + const _Float16 * V_h_f16 = reinterpret_cast(V_h); + _Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ); + _Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ); +#else + const half * K_h_f16 = K_h; + const half * V_h_f16 = V_h; + half * KQ_f16 = KQ; + half * VKQ_f16 = VKQ; +#endif + #pragma unroll for (int j0 = 0; j0 < ncols; j0 += nwarps) { const int j = j0 + threadIdx.y; @@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16( for (int i0 = 0; i0 < D; i0 += 16) { #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { - wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded); + wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded); } } @@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) { frag_a_K K_a; - wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); + wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]); @@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; wmma::load_matrix_sync( KQ_b[k0/(VKQ_ratio*16)][j0/frag_n], - KQ + j0*(kqar*kqs_padded) + k, + KQ_f16 + j0*(kqar*kqs_padded) + k, kqar*kqs_padded); } } @@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16( const int k = k0 + (threadIdx.y % VKQ_ratio)*16; frag_a_V v_a; - wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); + wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV); #pragma unroll for (int j = 0; j < ncols/frag_n; ++j) { wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]); @@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16( #pragma unroll for (int j0 = 0; j0 < ncols; j0 += frag_n) { wmma::store_matrix_sync( - KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), + KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio), VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n], D_padded, wmma::mem_col_major); } diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index b163468789..bed5c71a1b 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -2872,6 +2872,7 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { const std::string ffn_moe_down_bias_prefix = "ffn_moe_down_biased"; const std::string nemotron_h_block_out_prefix = "nemotron_h_block_out"; const std::string mamba2_y_add_d_prefix = "mamba2_y_add_d"; + const std::string delta_net_prefix = "dnet_add"; for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; @@ -2902,7 +2903,8 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { strncmp(node->name, ffn_moe_up_bias_prefix.c_str(), ffn_moe_up_bias_prefix.size()) != 0 && strncmp(node->name, ffn_moe_down_bias_prefix.c_str(), ffn_moe_down_bias_prefix.size()) != 0 && strncmp(node->name, nemotron_h_block_out_prefix.c_str(), nemotron_h_block_out_prefix.size()) != 0 && - strncmp(node->name, 
mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0) { + strncmp(node->name, mamba2_y_add_d_prefix.c_str(), mamba2_y_add_d_prefix.size()) != 0 && + strncmp(node->name, delta_net_prefix.c_str(), delta_net_prefix.size()) != 0) { // disable CUDA graphs for batch size > 1 for now while excluding the matrix-matrix addition as part of Gemma3n's `project_per_layer_input` operation // by means of matching node names. See // https://github.com/ggml-org/llama.cpp/blob/f9a31eea06a859e34cecb88b4d020c7f03d86cc4/src/llama-model.cpp#L10199-L10241 and @@ -3640,11 +3642,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud n_fuse++; if (n_fuse > 1) { + ggml_tensor fused_add_node; + memcpy(&fused_add_node, node, sizeof(ggml_tensor)); for (int j = 0; j < n_fuse - 1; ++j) { - node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; + fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1]; } - cgraph->nodes[i + n_fuse - 1]->data = node->data; - ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse); + fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data; + ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse); i += n_fuse - 1; continue; @@ -4542,6 +4546,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_UNARY_OP_CEIL: case GGML_UNARY_OP_ROUND: case GGML_UNARY_OP_TRUNC: + // TODO: should become: + //return ggml_is_contiguous_rows(op->src[0]); return ggml_is_contiguous(op->src[0]); default: return false; @@ -4820,8 +4826,11 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_TRANSPOSE_2D: case GGML_OP_POOL_2D: - case GGML_OP_ACC: return true; + case GGML_OP_ACC: + // TODO: extend support like so: + //return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]); + return ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); case GGML_OP_TOP_K: diff --git a/ggml/src/ggml-cuda/mmq.cuh b/ggml/src/ggml-cuda/mmq.cuh index f80f98cda2..255e59f6fc 100644 --- a/ggml/src/ggml-cuda/mmq.cuh +++ b/ggml/src/ggml-cuda/mmq.cuh @@ -2715,14 +2715,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XXS; ++l) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[l]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*l)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[l]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * l)); - const int signs0 = __vcmpne4(((signs_packed & 0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid0; @@ -2733,12 +2733,12 @@ template static __device__ __forceinline__ void loa #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } - const int ls = aux32 >> 28; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) const float d 
= bxi->d; #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) - x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = (ls*d + d/2)/4; + x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #else - x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = (ls*d + d/2)/4; + x_df[i*(MMQ_TILE_NE_K/4) + i/4 + kqsx] = d * ls / 8; // (d * scale + d / 2) / 4 #endif // defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) } } @@ -2776,11 +2776,14 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR2_XS; ++l) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l] >> 9)); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l] >> 9); - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q3_K + 8*kqsx + (2*l + 0)] = grid_l; @@ -2904,11 +2907,13 @@ template static __device__ __forceinline__ void loa #pragma unroll for (int l = 0; l < QR3_XXS; ++l) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[2*l+0]], iq3xxs_grid[q3[2*l+1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l)) & 0x7F)); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); #if defined(AMD_MFMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + (2*l + 0)] = grid_l; diff --git a/ggml/src/ggml-cuda/vecdotq.cuh b/ggml/src/ggml-cuda/vecdotq.cuh index 6baab1176f..ab803aca21 100644 --- a/ggml/src/ggml-cuda/vecdotq.cuh +++ b/ggml/src/ggml-cuda/vecdotq.cuh @@ -94,6 +94,15 @@ static __device__ __forceinline__ int2 get_int_from_table_16(const int & q4, con #endif } +static __device__ __forceinline__ uint32_t unpack_ksigns(const uint8_t v) { + // v is a 7 bit int, with the 8th sign being encodable as popcnt + // with xor we can "correct" the bit instead of having to mask + const uint32_t p = __popc(v) & 1; + const uint32_t s = v ^ p << 7; + // broadcast over uint to allow for 0x08040201 / 0x80402010 as selectors + return s * 0x01010101; +} + // VDR = vec dot ratio, how many contiguous integers each thread processes when the vec dot kernel is called // MMVQ = mul_mat_vec_q, MMQ = mul_mat_q @@ -905,22 +914,22 @@ static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( int sumi = 0; #pragma unroll for (int k0 = 0; k0 < 8; k0 += 2) { - const int * grid_pos = (const int *) (iq2xxs_grid + aux8[k0/2]); - const int signs_packed = ksigns_iq2xs[(aux32 >> (7*k0/2)) & 0x7F]; + const uint2 grid_pos = ((const uint2*)iq2xxs_grid)[aux8[k0/2]]; + const uint32_t signs = unpack_ksigns(aux32 >> (7 * k0 / 2)); - const int signs0 = __vcmpne4(((signs_packed & 
0x03) << 7) | ((signs_packed & 0x0C) << 21), 0x00000000); - const int grid0 = __vsub4(grid_pos[0] ^ signs0, signs0); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid0 = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, k0 + 0); sumi = ggml_cuda_dp4a(grid0, u0, sumi); - const int signs1 = __vcmpne4(((signs_packed & 0x30) << 3) | ((signs_packed & 0xC0) << 17), 0x00000000); - const int grid1 = __vsub4(grid_pos[1] ^ signs1, signs1); + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid1 = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, k0 + 1); sumi = ggml_cuda_dp4a(grid1, u1, sumi); } - const int ls = aux32 >> 28; - sumi = (ls*sumi + sumi/2)/4; + const int ls = aux32 >> 27 | 1; // (scale * 2 + 1) + sumi = sumi * ls / 8; // (sumi * scale + sumi / 2) / 4 const float d = __half2float(bq2->d) * __low2float(bq8_1[iqs/2].ds); return d * sumi; } @@ -942,13 +951,15 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( int sumi1 = 0; #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { - const uint32_t * grid_pos = (const uint32_t *)(iq2xs_grid + (q2[l0/2] & 0x000001FF)); - const uint32_t * signs = (const uint32_t *)(ksigns64 + (q2[l0/2] >> 9)); - - const int grid_l = __vsub4(grid_pos[0] ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos[1] ^ signs[1], signs[1]); + const uint2 grid_pos = ((const uint2*)iq2xs_grid)[q2[l0/2] & 0x1FF]; + const uint32_t signs = unpack_ksigns(q2[l0/2] >> 9); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); if (l0 < 4) { @@ -1028,13 +1039,16 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( #pragma unroll for (int l0 = 0; l0 < 8; l0 += 2) { const int2 grid_pos = make_int2(iq3xxs_grid[q3[l0 + 0]], iq3xxs_grid[q3[l0 + 1]]); + const uint32_t signs = unpack_ksigns(aux32 >> (7*l0/2)); - const int * signs = (const int *)(ksigns64 + ((aux32 >> (7*l0/2)) & 0x7F)); - - const int grid_l = __vsub4(grid_pos.x ^ signs[0], signs[0]); - const int grid_h = __vsub4(grid_pos.y ^ signs[1], signs[1]); + const int signs0 = __vcmpne4(signs & 0x08040201, 0); + const int grid_l = __vsub4(grid_pos.x ^ signs0, signs0); const int u0 = get_int_b4(bq8_1[iqs/2].qs, l0 + 0); + + const int signs1 = __vcmpne4(signs & 0x80402010, 0); + const int grid_h = __vsub4(grid_pos.y ^ signs1, signs1); + const int u1 = get_int_b4(bq8_1[iqs/2].qs, l0 + 1); sumi = ggml_cuda_dp4a(grid_l, u0, sumi); diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c index c184637443..74c777d4c3 100644 --- a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -17,121 +17,6 @@ #include "htp-msg.h" #include "htp-ops.h" -static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) { - HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements - HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements - return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float 
s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(4) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x_hf = vx[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x_hf = Q6_V_vand_QV(bmask, x_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); - - rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum)); - } - - rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum)); - hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum)); -} - -// Dot product of FP32 and FP16 vectors, accumulating to float -static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r, - const void * restrict y, - const void * restrict x0, - const void * restrict x1, - unsigned int n, - float s) { - const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 - const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16 - const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 - - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements - - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); - - uint32_t i = 0; - - #pragma unroll(2) - for (i = 0; i < nvec; i++) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - if (nloe) { - // Load y (fp32) and convert into fp16 - HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero); - - // Load x (fp16) - HVX_Vector x0_hf = vx0[i]; - HVX_Vector x1_hf = vx1[i]; - - // Zero-out unused elements - // Note that we need to clear both x and y because they may contain NANs - HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - x0_hf = Q6_V_vand_QV(bmask, x0_hf); - x1_hf = Q6_V_vand_QV(bmask, x1_hf); - y_hf = Q6_V_vand_QV(bmask, y_hf); - - HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); - HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); - - rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0)); - rsum1 
= Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1)); - } - - HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1)); - hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); -} - // Dot product of two F16 vectors, accumulating to float static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 @@ -140,8 +25,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -156,11 +40,10 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); @@ -181,12 +64,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16 const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 - uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors - uint32_t nloe = n % VLEN_FP16; // leftover elements + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_Vector zero = Q6_V_vsplat_R(0); - HVX_Vector rsum0 = Q6_V_vsplat_R(0); - HVX_Vector rsum1 = Q6_V_vsplat_R(0); + HVX_Vector rsum0 = Q6_V_vsplat_R(0); + HVX_Vector rsum1 = Q6_V_vsplat_R(0); uint32_t i = 0; @@ -204,12 +86,11 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, } if (nloe) { - HVX_Vector y_hf = vy[i]; - // Load x (fp16) and zero-out unused elements HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); - HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); - HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, vy[i]); HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); @@ -222,7 +103,7 @@ static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r, hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum)); } -// MAD: y (F32) += x (F16) * s (float) +// MAD: y (F32) += x (F16) * s (F32) static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; HVX_Vector * restrict ptr_y = (HVX_Vector *) y; @@ -259,15 +140,125 @@ static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict } } +// MAD: y (F32) += x0 (F16) * s0 (F32) + x1 (F16) * s1 (F32) +static inline void hvx_mad_f32_f16_aa_rx2(float * restrict y, + const void * restrict x0, + const void * restrict x1, + float s0, + float s1, + int n) { + const HVX_Vector * restrict ptr_x0 = (const HVX_Vector *) x0; + const HVX_Vector * restrict ptr_x1 = (const HVX_Vector *) x1; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / 
VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S0 = hvx_vec_splat_f16(s0); + HVX_Vector S1 = hvx_vec_splat_f16(s1); + + uint32_t i = 0; + #pragma unroll(2) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs_p_hi = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + + ptr_y[i * 2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_lo, ptr_y[i * 2])); + ptr_y[i * 2 + 1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs_p_hi, ptr_y[i * 2 + 1])); + } + + if (nloe) { + HVX_VectorPair xs0_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x0[i]), S0); + HVX_VectorPair xs1_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x1[i]), S1); + + HVX_Vector xs_p_lo = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xs0_p), Q6_V_lo_W(xs1_p)); + HVX_Vector xs = xs_p_lo; + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; + xs = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_hi_W(xs0_p), Q6_V_hi_W(xs1_p)); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_a(&ptr_y[i], nloe * 4, xy); + } + } +} + #define FLASH_ATTN_BLOCK_SIZE 128 -static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { +struct htp_fa_context { + const struct htp_ops_context * octx; + + struct fastdiv_values src0_div21; + struct fastdiv_values src0_div1; + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values src3_div2; + struct fastdiv_values src3_div3; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t n_head_log2; + float m0; + float m1; + + uint32_t n_blocks; + + size_t size_q_row_padded; + size_t size_k_row_padded; + size_t size_v_row_padded; + + size_t size_k_block; + size_t size_v_block; + size_t size_m_block; + + bool is_q_fp32; +}; + +static inline void hvx_scale_vec_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, HVX_Vector vs) { + assert((size_t) dst % 128 == 0); + assert((size_t) src % 128 == 0); + + const HVX_Vector * restrict vsrc = (const HVX_Vector * restrict) src; + HVX_Vector * restrict vdst = (HVX_Vector * restrict) dst; + + const uint32_t nvec = n / VLEN_FP32; + const uint32_t nloe = n % VLEN_FP32; + + uint32_t i = 0; + #pragma unroll(4) + for (; i < nvec; ++i) { + vdst[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs)); + } + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_a(&vdst[i], nloe * sizeof(float), Q6_Vsf_equals_Vqf32(v)); + } +} + +static void flash_attn_ext_f16_thread(unsigned int nth, unsigned int ith, void * data) { + struct htp_fa_context * factx = (struct htp_fa_context *) data; + const struct htp_ops_context * octx = factx->octx; const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; const struct htp_tensor * sinks = (octx->src4.data) ? 
&octx->src4 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * dst = &octx->dst; const uint32_t neq0 = q->ne[0]; const uint32_t neq1 = q->ne[1]; @@ -304,18 +295,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t nb2 = dst->nb[2]; const uint32_t nb3 = dst->nb[3]; - float scale = 1.0f; - float max_bias = 0.0f; - float logit_softcap = 0.0f; - - memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); - memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); - memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); - - if (logit_softcap != 0) { - scale /= logit_softcap; - } - // total rows in q const uint32_t nr = neq1*neq2*neq3; @@ -331,18 +310,8 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const uint32_t DV = nev0; const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); - const size_t size_q_row_padded = hex_round_up(size_q_row, 128); - const size_t size_k_row = DK * sizeof(__fp16); const size_t size_v_row = DV * sizeof(__fp16); - const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask - - const size_t size_k_row_padded = hex_round_up(size_k_row, 128); - const size_t size_v_row_padded = hex_round_up(size_v_row, 128); - - const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - const size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; @@ -351,31 +320,28 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; - const uint32_t n_head = neq2; - const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); - const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); - const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + const HVX_Vector logit_cap = hvx_vec_splat_f32(factx->logit_softcap); for (uint32_t ir = ir0; ir < ir1; ++ir) { - const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); - const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq3 = fastdiv(ir, &factx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &factx->src0_div1); const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); - const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); - const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + const uint32_t ik3 = fastdiv(iq3, &factx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &factx->broadcast_rk2); - const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); - const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + const uint32_t iv3 = fastdiv(iq3, &factx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &factx->broadcast_rv2); // Fetch Q row const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); - dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), factx->size_q_row_padded, nbq1, size_q_row, 1); const uint32_t h = iq2; // head index - const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? 
powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + const float slope = (factx->max_bias > 0.0f) ? (h < factx->n_head_log2 ? powf(factx->m0, h + 1) : powf(factx->m1, 2*(h - factx->n_head_log2) + 1)) : 1.0f; - float S = 0.0f; // sum - float M = -INFINITY; // maximum KQ value + HVX_Vector S_vec = hvx_vec_splat_f32(0.0f); + HVX_Vector M_vec = hvx_vec_splat_f32(-INFINITY); // Clear accumulator hvx_splat_f32_a(spad_a, 0, DV); @@ -383,40 +349,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in const __fp16 * mp_base = NULL; if (mask) { - const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); - const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &factx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &factx->src3_div3); mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); } - const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; - // Prefetch first two blocks - for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + for (uint32_t ib = 0; ib < MIN(factx->n_blocks, 2); ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; - dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + uint8_t * k_dst = spad_k + (ib % 2) * factx->size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), factx->size_k_row_padded, nbk1, size_k_row, current_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; - dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + uint8_t * v_dst = spad_v + (ib % 2) * factx->size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), factx->size_v_row_padded, nbv1, size_v_row, current_block_size); // Mask if (mask) { const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); - uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + uint8_t * m_dst = spad_m + (ib % 2) * factx->size_m_block; // Mask is 1D contiguous for this row dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); } } - const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + if (factx->is_q_fp32) { + hvx_copy_f16_f32_aa(q_ptr_vtcm, q_ptr_vtcm, DK); // inplace convert f32 to f16 + } - for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const HVX_Vector slope_vec = hvx_vec_splat_f16(slope); + for (uint32_t ib = 0; ib < factx->n_blocks; ++ib) { const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); @@ -428,8 +396,6 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in // Inner loop processing the block from VTCM uint32_t ic = 0; - const bool is_q_fp32 = (q->type == HTP_TYPE_F32); - // Process in blocks of 32 (VLEN_FP32) static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); HVX_Vector_x4 scores_x4; @@ -437,22 +403,18 @@ static void 
flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { // 1. Compute scores float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; - for (int j = 0; j < VLEN_FP32; j += 2) { + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { const uint32_t cur_ic = ic + j; - const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; - if (is_q_fp32) { - hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } else { - hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale); - } + const uint8_t * k_ptr = k_base + cur_ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + factx->size_k_row_padded, DK, factx->scale); } HVX_Vector scores = *(HVX_Vector *) scores_arr; // 2. Softcap - if (logit_softcap != 0.0f) { + if (factx->logit_softcap != 0.0f) { scores = hvx_vec_tanh_f32(scores); - scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_f32(logit_softcap)); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, logit_cap); scores = Q6_Vsf_equals_Vqf32(scores); } @@ -460,70 +422,59 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in if (mask) { const __fp16 * mp = m_base + ic; HVX_Vector m_vals_f16 = *(const HVX_UVector *) mp; - - HVX_Vector one_f16 = Q6_Vh_vsplat_R(0x3c00); - HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), one_f16); - - HVX_Vector m_vals_f32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_f32_pair)); - - HVX_Vector slope_vec = hvx_vec_splat_f32(slope); - HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_f32, slope_vec); - scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + HVX_VectorPair m_vals_f32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_f16), slope_vec); + HVX_Vector add_val = Q6_V_lo_W(m_vals_f32_pair); + scores = Q6_Vqf32_vadd_Vqf32Vsf(add_val, scores); scores = Q6_Vsf_equals_Vqf32(scores); } scores_x4.v[iv] = scores; - v_max = Q6_Vsf_vmax_VsfVsf(scores, v_max); + v_max = hvx_vec_reduce_max2_f32(scores, v_max); // All lanes have block max } { // 4. Online Softmax Update - v_max = hvx_vec_reduce_max_f32(v_max); - float m_block = hvx_vec_get_f32(v_max); - float M_old = M; - float M_new = (m_block > M) ? m_block : M; - M = M_new; + HVX_Vector M_new_vec = Q6_Vsf_vmax_VsfVsf(v_max, M_vec); + HVX_Vector diff_vec = Q6_Vqf32_vsub_VsfVsf(M_vec, M_new_vec); + HVX_Vector ms_vec = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(diff_vec)); + M_vec = M_new_vec; - const float ms = expf(M_old - M_new); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - HVX_Vector M_new_vec = hvx_vec_splat_f32(M_new); HVX_Vector p_sum_vec = hvx_vec_splat_f32(0.0f); for (uint32_t ic2 = 0, iv = 0; ic2 + VLEN_FP32 <= current_block_size; ic2 += VLEN_FP32, ++iv) { HVX_Vector scores = scores_x4.v[iv]; - HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_vec); HVX_Vector P = hvx_vec_exp_f32(Q6_Vsf_equals_Vqf32(scores_shifted)); p_sum_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(p_sum_vec, P)); // 5. 
Accumulate V float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; - *(HVX_Vector*)p_arr = P; + *(HVX_Vector *) p_arr = P; - for (int j = 0; j < VLEN_FP32; ++j) { - const uint32_t cur_ic = ic2 + j; - const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; - hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + for (uint32_t j = 0; j < VLEN_FP32; j += 2) { + const uint32_t cur_ic = ic2 + j; + const uint8_t * v_ptr = v_base + cur_ic * factx->size_v_row_padded; + hvx_mad_f32_f16_aa_rx2(VKQ32, v_ptr, v_ptr + factx->size_v_row_padded, p_arr[j], p_arr[j + 1], DV); } } p_sum_vec = hvx_vec_reduce_sum_f32(p_sum_vec); - S = S * ms + hvx_vec_get_f32(p_sum_vec); + S_vec = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(S_vec, ms_vec)), p_sum_vec)); } + // Sync scalars for leftover/next block if needed + float M = hvx_vec_get_f32(M_vec); + float S = hvx_vec_get_f32(S_vec); + // Leftover for (; ic < current_block_size; ++ic) { float s_val; - const uint8_t * k_ptr = k_base + ic * size_k_row_padded; - - if (is_q_fp32) { - hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } else { - hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); - } - - if (logit_softcap != 0.0f) { - s_val = logit_softcap * tanhf(s_val); + const uint8_t * k_ptr = k_base + ic * factx->size_k_row_padded; + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, factx->scale); + if (factx->logit_softcap != 0.0f) { + s_val = factx->logit_softcap * tanhf(s_val); } if (mask) { @@ -532,37 +483,42 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } const float Mold = M; - float ms = 1.0f; float vs = 1.0f; if (s_val > M) { M = s_val; - ms = expf(Mold - M); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + HVX_Vector diff_vec = hvx_vec_splat_f32(Mold - M); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); + + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; } else { - vs = expf(s_val - M); + HVX_Vector diff_vec = hvx_vec_splat_f32(s_val - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; } - const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + const uint8_t * v_ptr = v_base + ic * factx->size_v_row_padded; hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); - - S = S * ms + vs; } + M_vec = hvx_vec_splat_f32(M); + S_vec = hvx_vec_splat_f32(S); // Issue DMA for next+1 block (if exists) - if (ib + 2 < n_blocks) { + if (ib + 2 < factx->n_blocks) { const uint32_t next_ib = ib + 2; const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); // K const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); - dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), factx->size_k_row_padded, nbk1, size_k_row, next_block_size); // V const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); - dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), factx->size_v_row_padded, nbv1, size_v_row, next_block_size); // Mask if (mask) { @@ -573,20 +529,26 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } // sinks + float M = hvx_vec_get_f32(M_vec); + float S = 
hvx_vec_get_f32(S_vec); + if (sinks) { const float s = ((float *)((char *) sinks->data))[h]; - float ms = 1.0f; float vs = 1.0f; if (s > M) { - ms = expf(M - s); - hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); - } else { - vs = expf(s - M); - } + HVX_Vector diff_vec = hvx_vec_splat_f32(M - s); + HVX_Vector ms_vec = hvx_vec_exp_f32(diff_vec); + hvx_scale_vec_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms_vec); - S = S * ms + vs; + float ms = hvx_vec_get_f32(ms_vec); + S = S * ms + vs; + } else { + HVX_Vector diff_vec = hvx_vec_splat_f32(s - M); + vs = hvx_vec_get_f32(hvx_vec_exp_f32(diff_vec)); + S += vs; + } } const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; @@ -609,53 +571,73 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in } } -static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - flash_attn_ext_f16_thread(octx, i, n); -} - int op_flash_attn_ext(struct htp_ops_context * octx) { const struct htp_tensor * q = &octx->src0; const struct htp_tensor * k = &octx->src1; const struct htp_tensor * v = &octx->src2; - const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; - struct htp_tensor * dst = &octx->dst; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * dst = &octx->dst; // Check support - if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || - k->type != HTP_TYPE_F16 || - v->type != HTP_TYPE_F16) { + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || k->type != HTP_TYPE_F16 || v->type != HTP_TYPE_F16) { return HTP_STATUS_NO_SUPPORT; } - octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); - octx->src0_div1 = init_fastdiv_values(q->ne[1]); + struct htp_fa_context factx; + factx.octx = octx; - octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); - octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); - octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); - octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + factx.src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + factx.src0_div1 = init_fastdiv_values(q->ne[1]); + + factx.broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + factx.broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + factx.broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + factx.broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); if (mask) { - octx->src3_div2 = init_fastdiv_values(mask->ne[2]); - octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + factx.src3_div2 = init_fastdiv_values(mask->ne[2]); + factx.src3_div3 = init_fastdiv_values(mask->ne[3]); } - size_t size_q_row_padded = hex_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); - size_t size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); - size_t size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); + factx.is_q_fp32 = (q->type == HTP_TYPE_F32); + factx.size_q_row_padded = hex_round_up(q->ne[0] * (factx.is_q_fp32 ? 
4 : 2), 128); + factx.size_k_row_padded = hex_round_up(k->ne[0] * sizeof(__fp16), 128); + factx.size_v_row_padded = hex_round_up(v->ne[0] * sizeof(__fp16), 128); - size_t size_q_block = size_q_row_padded * 1; // single row for now - size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; - size_t size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + size_t size_q_block = factx.size_q_row_padded * 1; // single row for now + factx.size_k_block = factx.size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_v_block = factx.size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + factx.size_m_block = hex_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + factx.n_blocks = (k->ne[1] + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0.0f) { + scale /= logit_softcap; + } + + factx.scale = scale; + factx.max_bias = max_bias; + factx.logit_softcap = logit_softcap; + + uint32_t n_head = q->ne[2]; + factx.n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + factx.m0 = powf(2.0f, -(max_bias ) / factx.n_head_log2); + factx.m1 = powf(2.0f, -(max_bias / 2.0f) / factx.n_head_log2); size_t size_vkq_acc = hex_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 octx->src0_spad.size_per_thread = size_q_block * 1; - octx->src1_spad.size_per_thread = size_k_block * 2; - octx->src2_spad.size_per_thread = size_v_block * 2; - octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->src1_spad.size_per_thread = factx.size_k_block * 2; + octx->src2_spad.size_per_thread = factx.size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? 
factx.size_m_block * 2 : 0; octx->dst_spad.size_per_thread = size_vkq_acc; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; @@ -677,7 +659,7 @@ int op_flash_attn_ext(struct htp_ops_context * octx) { octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + worker_pool_run_func(octx->ctx->worker_pool, flash_attn_ext_f16_thread, &factx, octx->n_threads); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index c0d72587ce..f1ad24dbfa 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -64,25 +64,12 @@ struct htp_ops_context { struct fastdiv_values broadcast_rv2; struct fastdiv_values broadcast_rv3; - struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 - struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 - struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 - struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 - struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 - struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01 - struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02 - struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03 - - struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00 - struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01 - struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02 - uint32_t flags; }; diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index 62708eee5c..92a1422896 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) { // otherwise we'll release it once we're done with the current Op. 
if (ctx->vtcm_inuse) { - ctx->vtcm_needs_release = false; + ctx->vtcm_needs_release = true; return 0; } diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index d251eeed33..c360abe8da 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -23,10 +23,30 @@ #define MM_SPAD_SRC1_NROWS 16 #define MM_SPAD_DST_NROWS 2 -struct htp_matmul_type { +struct htp_matmul_context { const char * type; - void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); + struct htp_ops_context * octx; + + void (*vec_dot_1x1)(const int n, float * restrict s0, + const void * restrict vx0, + const void * restrict vy0); + + void (*vec_dot_2x1)(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0); + + void (*vec_dot_2x2)(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1); + + // Precomputed values + uint32_t src0_nrows_per_thread; + uint32_t src1_nrows_per_thread; + + struct fastdiv_values mm_div_ne12_ne1; + struct fastdiv_values mm_div_ne1; + struct fastdiv_values mm_div_r2; + struct fastdiv_values mm_div_r3; }; // vdelta control to replicate first 4x fp32 values across lanes @@ -122,6 +142,7 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v6_7 = vptr[3]; // ... const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector i8 = Q6_Vb_vsplat_R(8); HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -133,15 +154,14 @@ static inline HVX_Vector_x8 hvx_vec_load_q4x4x8(const uint8_t * restrict ptr) { HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 // Convert uint4 to int4 (i.e. x - 8) - const HVX_Vector i8 = Q6_Vb_vsplat_R(8); - v0 = Q6_Vb_vsub_VbVb(v0, i8); - v1 = Q6_Vb_vsub_VbVb(v1, i8); - v2 = Q6_Vb_vsub_VbVb(v2, i8); - v3 = Q6_Vb_vsub_VbVb(v3, i8); - v4 = Q6_Vb_vsub_VbVb(v4, i8); - v5 = Q6_Vb_vsub_VbVb(v5, i8); - v6 = Q6_Vb_vsub_VbVb(v6, i8); - v7 = Q6_Vb_vsub_VbVb(v7, i8); + v0 = Q6_Vb_vsub_VbVb(v0, i8); + v1 = Q6_Vb_vsub_VbVb(v1, i8); + v2 = Q6_Vb_vsub_VbVb(v2, i8); + v3 = Q6_Vb_vsub_VbVb(v3, i8); + v4 = Q6_Vb_vsub_VbVb(v4, i8); + v5 = Q6_Vb_vsub_VbVb(v5, i8); + v6 = Q6_Vb_vsub_VbVb(v6, i8); + v7 = Q6_Vb_vsub_VbVb(v7, i8); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -156,6 +176,7 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6_7 = vptr[3]; // ... 
const HVX_Vector mask_h4 = Q6_Vb_vsplat_R(0x0F); + const HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; HVX_Vector v0 = Q6_V_vand_VV(v0_1, mask_h4); // & 0x0F HVX_Vector v1 = Q6_Vub_vlsr_VubR(v0_1, 4); // >> 4 @@ -166,15 +187,14 @@ static inline HVX_Vector_x8 hvx_vec_load_mxfp4x4x8(const uint8_t * restrict ptr) HVX_Vector v6 = Q6_V_vand_VV(v6_7, mask_h4); // & 0x0F HVX_Vector v7 = Q6_Vub_vlsr_VubR(v6_7, 4); // >> 4 - HVX_Vector lut = *(const HVX_Vector *) kvalues_mxfp4_lut; - v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); - v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); - v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); - v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); - v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); - v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); - v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); - v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); + v0 = Q6_Vb_vlut32_VbVbI(v0, lut, 0); + v1 = Q6_Vb_vlut32_VbVbI(v1, lut, 0); + v2 = Q6_Vb_vlut32_VbVbI(v2, lut, 0); + v3 = Q6_Vb_vlut32_VbVbI(v3, lut, 0); + v4 = Q6_Vb_vlut32_VbVbI(v4, lut, 0); + v5 = Q6_Vb_vlut32_VbVbI(v5, lut, 0); + v6 = Q6_Vb_vlut32_VbVbI(v6, lut, 0); + v7 = Q6_Vb_vlut32_VbVbI(v7, lut, 0); HVX_Vector_x8 r = { v0, v1, v2, v3, v4, v5, v6, v7 }; return r; @@ -196,46 +216,6 @@ static inline HVX_Vector_x8 hvx_vec_load_q8x4x8(const uint8_t * restrict ptr) { return r; } -static inline HVX_Vector_x4 hvx_vec_load_x4_f16(const uint8_t * restrict ptr) { - const HVX_Vector * restrict vptr = (const HVX_Vector *) ptr; - - HVX_Vector v0 = vptr[0]; // first 64 vals - HVX_Vector v1 = vptr[1]; // second 64 vals - HVX_Vector v2 = vptr[2]; // third 64 vals - HVX_Vector v3 = vptr[3]; // forth 64 vals - - HVX_Vector_x4 r = { v0, v1, v2, v3 }; - return r; -} - -static inline HVX_Vector_x4 hvx_vec_load_x4_f32_as_f16(const uint8_t * restrict ptr) { - const HVX_VectorPair * restrict vptr = (const HVX_VectorPair *) ptr; - - HVX_VectorPair v0 = vptr[0]; // first 64 vals - HVX_VectorPair v1 = vptr[1]; // second 64 vals - HVX_VectorPair v2 = vptr[2]; // third 64 vals - HVX_VectorPair v3 = vptr[3]; // forth 64 vals - - HVX_Vector vq0_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v0), Q6_V_vzero()); - HVX_Vector vq0_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v0), Q6_V_vzero()); - HVX_Vector vq1_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v1), Q6_V_vzero()); - HVX_Vector vq1_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v1), Q6_V_vzero()); - HVX_Vector vq2_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v2), Q6_V_vzero()); - HVX_Vector vq2_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v2), Q6_V_vzero()); - HVX_Vector vq3_lo = Q6_Vqf32_vsub_VsfVsf(Q6_V_lo_W(v3), Q6_V_vzero()); - HVX_Vector vq3_hi = Q6_Vqf32_vsub_VsfVsf(Q6_V_hi_W(v3), Q6_V_vzero()); - - HVX_Vector vh0 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq0_hi, vq0_lo)); - HVX_Vector vh1 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq1_hi, vq1_lo)); - HVX_Vector vh2 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq2_hi, vq2_lo)); - HVX_Vector vh3 = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(vq3_hi, vq3_lo)); - - // vcombine does a shuffle, use vdeal to undo - - HVX_Vector_x4 r = { Q6_Vh_vdeal_Vh(vh0), Q6_Vh_vdeal_Vh(vh1), Q6_Vh_vdeal_Vh(vh2), Q6_Vh_vdeal_Vh(vh3) }; - return r; -} - // Reduce multiply 1024 x 1024 int8 elements (32x q4/8 blocks in 8x HVX vectors). // Accumulate each block into a single int32 value. // Return a single HVX vector with 32x int32 accumulators. 
@@ -300,26 +280,26 @@ static inline HVX_Vector hvx_vec_rmpy_x8_nloe(HVX_Vector_x8 x, HVX_Vector_x8 y, return hvx_vec_rmpy_x8_n(x, y, 1024); } -static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -372,36 +352,34 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk / 2; // int4 - const uint32_t x_qrow_size = n / 2; // int4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const 
uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -468,13 +446,143 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_q4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q4_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk / 2; // int4 + const uint32_t x_qrow_size = n / 2; // int4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + 
i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = 
Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(s0, 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(s1, 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_q8x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; @@ -486,11 +594,11 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -543,36 +651,34 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_Q4_0x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t x_qblk_size = qk; // int8 - const uint32_t x_qrow_size = n; // int8 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) - const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 
2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (qf32) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -639,16 +745,143 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_mxfp4x4x2_q8x4x2(const int n, - float * restrict s, - const void * restrict vx, - const void * restrict vy) { +static void vec_dot_q8x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_Q8_0x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t x_qblk_size = qk; // int8 + const uint32_t x_qrow_size = n; // int8 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 
vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_q8x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_q8x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y0_d + i * y_dblk_size)); + HVX_Vector vy1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (y1_d + i * y_dblk_size)); + HVX_Vector r0_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r0_x_d + i * x_dblk_size)); + HVX_Vector r1_d = Q6_Vh_vshuff_Vh(*(const HVX_UVector *) (r1_x_d + i * x_dblk_size)); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy0_d))); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r0_d, vy1_d))); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy0_d))); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(r1_d, vy1_d))); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = 
Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_mxfp4x4x2_q8x4x2_1x1(const int n, float * restrict s0, const void * restrict vx0, const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; @@ -660,11 +893,11 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, const uint32_t y_qblk_size = qk; // int8 const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) vx + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) vx + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0 + 0); // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0 + x_qrow_size); // then scales - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0 + 0); // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0 + y_qrow_size); // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -747,36 +980,34 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n, r0_sum = hvx_vec_reduce_sum_f32(r0_sum); - hvx_vec_store_u(&s[0], 4, r0_sum); + hvx_vec_store_u(s0, 4, r0_sum); } -static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { assert(n % 32 == 0); // min sub-block size - assert((unsigned long) vx % 128 == 0); - assert((unsigned long) vy % 128 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); const uint32_t qk = QK_MXFP4x4x2 * 4; - const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 - const uint32_t x_qblk_size = qk / 2; // fp4 - const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) - 
const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 - const uint32_t y_qblk_size = qk; // int8 - const uint32_t y_qrow_size = n; // int8 (not padded) + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) - const uint8_t * restrict r0_x_q = ((const uint8_t *) (vx + (0 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r0_x_d = ((const uint8_t *) (vx + (0 * vx_row_size)) + x_qrow_size); // then scales + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales - const uint8_t * restrict r1_x_q = ((const uint8_t *) (vx + (1 * vx_row_size)) + 0); // quants first - const uint8_t * restrict r1_x_d = ((const uint8_t *) (vx + (1 * vx_row_size)) + x_qrow_size); // then scales - - const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first - const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales + const uint8_t * restrict y_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales // Row sum (sf) HVX_Vector r0_sum = Q6_V_vsplat_R(0); @@ -879,10 +1110,180 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_mxfp4x4x2_q8x4x2_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + assert(n % 32 == 0); + assert((unsigned long) vx0 % 128 == 0); + assert((unsigned long) vx1 % 128 == 0); + assert((unsigned long) vy0 % 128 == 0); + assert((unsigned long) vy1 % 128 == 0); + + const uint32_t qk = QK_MXFP4x4x2 * 4; + + const uint32_t x_dblk_size = 8 * 4 * 1; // 32x e8m0 + const uint32_t x_qblk_size = qk / 2; // fp4 + const uint32_t x_qrow_size = n / 2; // fp4 (not padded) + + const uint32_t y_dblk_size = 8 * 4 * 2; // 32x __fp16 + const uint32_t y_qblk_size = qk; // int8 + const uint32_t y_qrow_size = n; // int8 (not padded) + + const uint8_t * restrict r0_x_q = ((const uint8_t *) vx0) + 0; // quants first + const uint8_t * restrict r0_x_d = ((const uint8_t *) vx0) + x_qrow_size; // then scales + const uint8_t * restrict r1_x_q = ((const uint8_t *) vx1) + 0; // quants first + const uint8_t * restrict r1_x_d = ((const uint8_t *) vx1) + x_qrow_size; // then scales + + const uint8_t * restrict y0_q = ((const uint8_t *) vy0) + 0; // quants first + const uint8_t * restrict y0_d = ((const uint8_t *) vy0) + y_qrow_size; // then scales + const uint8_t * restrict y1_q = ((const uint8_t *) vy1) + 0; // quants first + const uint8_t * restrict y1_d = ((const uint8_t *) vy1) + y_qrow_size; // then scales + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + const uint32_t nb = n / qk; // num full blocks + const uint32_t nloe = n % qk; // num leftover 
elements + + uint32_t i = 0; + for (; i < nb; i++) { + // Load src1 columns (reused across both src0 rows) + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + + // Load src0 rows (reused across both src1 columns) + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy0_q)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r0_q, vy1_q)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy0_q)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_full(r1_q, vy1_q)); + + // Load scales + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + // Compute combined scales + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Apply scales and accumulate + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Process leftovers + if (nloe) { + HVX_Vector_x8 vy0_q = hvx_vec_load_q8x4x8(y0_q + i * y_qblk_size); + HVX_Vector_x8 vy1_q = hvx_vec_load_q8x4x8(y1_q + i * y_qblk_size); + HVX_Vector_x8 r0_q = hvx_vec_load_mxfp4x4x8(r0_x_q + i * x_qblk_size); + HVX_Vector_x8 r1_q = hvx_vec_load_mxfp4x4x8(r1_x_q + i * x_qblk_size); + + HVX_Vector r0_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy0_q, nloe)); + HVX_Vector r0_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r0_q, vy1_q, 
nloe)); + HVX_Vector r1_c0_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy0_q, nloe)); + HVX_Vector r1_c1_ia = Q6_Vsf_equals_Vw(hvx_vec_rmpy_x8_nloe(r1_q, vy1_q, nloe)); + + HVX_Vector vy0_d = *(const HVX_UVector *) (y0_d + i * y_dblk_size); + HVX_Vector vy1_d = *(const HVX_UVector *) (y1_d + i * y_dblk_size); + HVX_Vector r0_d = *(const HVX_UVector *) (r0_x_d + i * x_dblk_size); + HVX_Vector r1_d = *(const HVX_UVector *) (r1_x_d + i * x_dblk_size); + + // Convert vy_d from fp16 to fp32 while applying 0.5 scaling which is used for e8m0 halving + HVX_Vector half = Q6_Vh_vsplat_R(0x3800); // 0.5 in fp16 + vy0_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy0_d), half)); + vy0_d = Q6_Vsf_equals_Vqf32(vy0_d); + vy1_d = Q6_V_lo_W(Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(vy1_d), half)); + vy1_d = Q6_Vsf_equals_Vqf32(vy1_d); + + // Convert rX_d scales from e8m0 to fp32 + // Expand and zero-pad 32x uint8 e8m0 values to uint32s : 0 0 0 0, 0 0 0 1, 0 0 0 2, ... + // Left shift with zero fill to create FP32 + // FIXME: might need to handle zero as a special case (see ggml-cpu code) + HVX_Vector expand = *(const HVX_Vector *) expand_x32_e8m0; + HVX_Vector e8m0_mask = Q6_V_vsplat_R(0x000000ff); + r0_d = Q6_V_vdelta_VV(r0_d, expand); + r0_d = Q6_V_vand_VV(r0_d, e8m0_mask); + r0_d = Q6_Vw_vasl_VwR(r0_d, 23); + r1_d = Q6_V_vdelta_VV(r1_d, expand); + r1_d = Q6_V_vand_VV(r1_d, e8m0_mask); + r1_d = Q6_Vw_vasl_VwR(r1_d, 23); + + HVX_Vector r0_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy0_d)); + HVX_Vector r0_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy1_d)); + HVX_Vector r1_c0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy0_d)); + HVX_Vector r1_c1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy1_d)); + + // Zero out unused scales + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); + r0_c0_dd = Q6_V_vand_QV(bmask, r0_c0_dd); + r0_c1_dd = Q6_V_vand_QV(bmask, r0_c1_dd); + r1_c0_dd = Q6_V_vand_QV(bmask, r1_c0_dd); + r1_c1_dd = Q6_V_vand_QV(bmask, r1_c1_dd); + r0_c0_ia = Q6_V_vand_QV(bmask, r0_c0_ia); + r0_c1_ia = Q6_V_vand_QV(bmask, r0_c1_ia); + r1_c0_ia = Q6_V_vand_QV(bmask, r1_c0_ia); + r1_c1_ia = Q6_V_vand_QV(bmask, r1_c1_ia); + + HVX_Vector r0_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c0_ia, r0_c0_dd); + HVX_Vector r0_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r0_c1_ia, r0_c1_dd); + HVX_Vector r1_c0_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c0_ia, r1_c0_dd); + HVX_Vector r1_c1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_c1_ia, r1_c1_dd); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_fa, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_fa, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_fa, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_fa, r1_c1_sum)); + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_aa_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_Vector * restrict x = (const HVX_Vector *) vx; const HVX_Vector * restrict y = (const HVX_Vector *) vy; @@ -912,14 +1313,12 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f16_aa_rx2(const int n, - float * restrict s, - 
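// Scalar sketch of the e8m0 -> fp32 conversion performed above with vdelta/vand/vasl:
// the 8-bit biased exponent is placed directly into the fp32 exponent field
// (left shift by 23 with zero fill), yielding 2^(e - 127). As the FIXME notes,
// e8m0 == 0 would need special handling; this sketch ignores that case, like
// the vector code. The 0.5 factor folded into the fp16 src1 scales above is a
// separate adjustment and is not part of this conversion.
#include <stdint.h>
#include <string.h>

static inline float e8m0_to_fp32_sketch(uint8_t e) {
    uint32_t bits = (uint32_t) e << 23; // sign = 0, mantissa = 0, exponent = e
    float f;
    memcpy(&f, &bits, sizeof(f));
    return f;
}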
const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy) { - const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; - const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); - const HVX_Vector * restrict y = (const HVX_Vector *) vy; +static void vec_dot_f16_f16_aa_2x1(const int n, float * restrict s0, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y = (const HVX_Vector *) vy0; uint32_t nvec = n / VLEN_FP16; uint32_t nloe = n % VLEN_FP16; @@ -953,10 +1352,86 @@ static void vec_dot_f16_f16_aa_rx2(const int n, } HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1)); - hvx_vec_store_u(&s[0], 8, rsum); + hvx_vec_store_u(s0, 8, rsum); } -static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { +static void vec_dot_f16_f16_aa_2x2(const int n, float * restrict s0, float * restrict s1, + const void * restrict vx0, const void * restrict vx1, + const void * restrict vy0, const void * restrict vy1) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx0; + const HVX_Vector * restrict x1 = (const HVX_Vector *) vx1; + const HVX_Vector * restrict y0 = (const HVX_Vector *) vy0; + const HVX_Vector * restrict y1 = (const HVX_Vector *) vy1; + + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; + + // Row sums (sf) - 4 accumulators for 2×2 tile + HVX_Vector r0_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r0_c1_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c0_sum = Q6_V_vsplat_R(0); + HVX_Vector r1_c1_sum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector r0_hf = x0[i]; + HVX_Vector r1_hf = x1[i]; + HVX_Vector c0_hf = y0[i]; + HVX_Vector c1_hf = y1[i]; + + // Compute 4 dot products: r0×c0, r0×c1, r1×c0, r1×c1 + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + + HVX_Vector r0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector r1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector c0_hf = Q6_V_vand_QV(bmask, y0[i]); + HVX_Vector c1_hf = Q6_V_vand_QV(bmask, y1[i]); + + HVX_VectorPair r0_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c0_hf); + HVX_VectorPair r0_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r0_hf, c1_hf); + HVX_VectorPair r1_c0_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c0_hf); + HVX_VectorPair r1_c1_qf_p = Q6_Wqf32_vmpy_VhfVhf(r1_hf, c1_hf); + + 
HVX_Vector r0_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c0_qf_p), Q6_V_hi_W(r0_c0_qf_p)); + HVX_Vector r0_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r0_c1_qf_p), Q6_V_hi_W(r0_c1_qf_p)); + HVX_Vector r1_c0_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c0_qf_p), Q6_V_hi_W(r1_c0_qf_p)); + HVX_Vector r1_c1_qf = Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(r1_c1_qf_p), Q6_V_hi_W(r1_c1_qf_p)); + + r0_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c0_qf, r0_c0_sum)); + r0_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_c1_qf, r0_c1_sum)); + r1_c0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c0_qf, r1_c0_sum)); + r1_c1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_c1_qf, r1_c1_sum)); + + } + + // Reduce and store results + HVX_Vector r0_r1_c0_sum = hvx_vec_reduce_sum_f32x2(r0_c0_sum, r1_c0_sum); + HVX_Vector r0_r1_c1_sum = hvx_vec_reduce_sum_f32x2(r0_c1_sum, r1_c1_sum); + + hvx_vec_store_u(&s0[0], 8, r0_r1_c0_sum); // row0,col0 row1,col0 + hvx_vec_store_u(&s1[0], 8, r0_r1_c1_sum); // row0,col1 row1,col1 +} + +static void vec_dot_f16_f16_uu_1x1(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { const HVX_UVector * restrict x = (const HVX_UVector *) vx; const HVX_UVector * restrict y = (const HVX_UVector *) vy; @@ -986,7 +1461,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res hvx_vec_store_u(&s[0], 4, rsum); } -static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { +static void vec_dot_f16_f32_uu_1x1(const int n, float * restrict s, const void * restrict x, const void * restrict y) { const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; @@ -1083,14 +1558,16 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -#define htp_matmul_preamble \ - htp_matmul_tensors_preamble; \ - dma_queue *dma_queue = octx->ctx->dma[ith]; \ - uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; +#define htp_matmul_preamble \ + struct htp_matmul_context * mmctx = data; \ + struct htp_ops_context * octx = mmctx->octx; \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = mmctx->src0_nrows_per_thread; // *** matmul with support for 4d tensors and full broadcasting -static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_4d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; uint64_t t1, t2; @@ -1136,13 +1613,13 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); - const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i13 = fastdiv(ir1, &mmctx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &mmctx->mm_div_ne1); const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); // broadcast src0 into src1 - const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); - const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + const uint32_t i03 = fastdiv(i13, &mmctx->mm_div_r3); + const uint32_t i02 = 
fastdiv(i12, &mmctx->mm_div_r2); const uint32_t i1 = i11; const uint32_t i2 = i12; @@ -1155,7 +1632,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_col[ir0], src0_row, src1_col); } } } @@ -1170,7 +1647,7 @@ static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx } // src1 tensor is already in VTCM spad -static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1195,7 +1672,7 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; + uint8_t * restrict spad_dst = dst_spad->data + dst_spad->size_per_thread * ith; uint8_t * restrict spad_src0 = src0_spad->data + src0_spad->size_per_thread * ith; uint8_t * restrict src1_data = src1_spad->data; @@ -1219,11 +1696,21 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - #pragma unroll(2) - for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { + // Process src1 columns in pairs (2×2 tiling) + uint32_t ir1 = 0; + for (; ir1 + 1 < src1_nrows; ir1 += 2) { + const uint8_t * restrict src1_col0 = (const uint8_t *) (src1_data + (ir1+0) * src1_stride); + const uint8_t * restrict src1_col1 = (const uint8_t *) (src1_data + (ir1+1) * src1_stride); + float * restrict dst_row0 = (float *) (dst->data + ((ir1+0) * dst_row_size)); + float * restrict dst_row1 = (float *) (dst->data + ((ir1+1) * dst_row_size)); + mmctx->vec_dot_2x2(ne00, &dst_row0[ir0], &dst_row1[ir0], ss0, ss0 + src0_stride, src1_col0, src1_col1); + } + + // Handle remaining src1 rows (fallback to 2×1) + for (; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1247,20 +1734,20 @@ static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matmul-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, 
src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_2d(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1311,7 +1798,7 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); + mmctx->vec_dot_2x1(ne00, &tmp[ir0 - src0_start_row], ss0, ss0 + src0_stride, src1_col); // Prefetch next (n + spad_nrows) row const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1329,14 +1816,14 @@ static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } hvx_copy_f32_ua((uint8_t *) &dst_col[src0_start_row], (uint8_t *) tmp, src0_end_row - src0_start_row); t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mt->type, ith, nth, + FARF(HIGH, "matvec-%s %u/%u: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1350,7 +1837,7 @@ struct mmid_row_mapping { }; // src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matmul_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1423,11 +1910,10 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); } // Prefetch next (n + spad_nrows) row @@ -1453,25 +1939,24 @@ static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx const int rm2 = row_mapping.i2; // token idx const uint32_t ir1 = src1_nrows == 1 ? 
0 : rm1; // src1 row idx - const uint8_t * restrict src1_col = - (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + (ir1 + rm2 * ne11 + 0) * src1_row_size); float * dst_row = (float *) (dst->data + (rm1 * nb1 + rm2 * nb2 + 0)); - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matmul-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } // src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { +static void matvec_id(unsigned int nth, unsigned int ith, void * data) { htp_matmul_preamble; struct htp_tensor * restrict ids = &octx->src2; @@ -1531,7 +2016,7 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mmctx->vec_dot_2x1(ne00, &dst_row[ir0], ss0, ss0 + src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); @@ -1549,13 +2034,13 @@ static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), src0_row_size_padded, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); + mmctx->vec_dot_1x1(ne00, &dst_row[ir0], ss0, src1_col); } } t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mt->type, + FARF(HIGH, "matvec-id-%s %d/%d: %ux%ux%ux%u (%u:%u) * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u usec %u\n", mmctx->type, ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], src2->ne[0], src2->ne[1], src2->ne[2], src2->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); @@ -1754,12 +2239,14 @@ static void quantize_row_f32_q8x4x2(float * restrict x, uint8_t * restrict y, ui hvx_copy_f16_ua(y_d, t_d, nb * 8); } -static void quantize_f32_q8x4x2(const struct htp_tensor * src, - uint8_t * restrict dst, - struct htp_spad * spad, - uint32_t nth, - uint32_t ith, - uint32_t nrows_per_thread) { +static void quantize_f32_q8x4x2(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + struct htp_spad * spad = &octx->src0_spad; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1799,8 +2286,14 
@@ static void quantize_f32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f32_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1835,8 +2328,14 @@ static void quantize_f32_f16(const struct htp_tensor * src, uint8_t * restrict d } // TODO just a plain copy that should be done via the DMA during the Op setup -static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, - uint32_t nrows_per_thread, uint32_t dst_stride) { +static void quantize_f16_f16(unsigned int nth, unsigned int ith, void * data) { + struct htp_matmul_context * mmctx = data; + struct htp_ops_context * octx = mmctx->octx; + + const struct htp_tensor * src = &octx->src1; + uint8_t * restrict dst = octx->src1_spad.data; + uint32_t nrows_per_thread = mmctx->src1_nrows_per_thread; + uint32_t dst_stride = octx->src1_spad.stride; uint64_t t1 = HAP_perf_get_qtimer_count(); @@ -1870,213 +2369,76 @@ static void quantize_f16_f16(const struct htp_tensor * src, uint8_t * restrict d ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -static void htp_quantize_f32_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); -} - -static void htp_quantize_f32_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f32_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -static void htp_quantize_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - quantize_f16_f16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); -} - -// ** matmul/matvec callbacks for worker_pool - -static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct 
htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matvec_2d(&mt, octx, n, i); -} - -static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_aa; - mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; - - matmul_2d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f32"; - mt.vec_dot = vec_dot_f16_f32_uu; - - matmul_4d(&mt, octx, n, i); -} - -static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "f16-f16"; - mt.vec_dot = vec_dot_f16_f16_uu; - - matmul_4d(&mt, octx, n, i); -} - -// ** matmul-id callbacks for worker_pool - -static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "q8x4x2-q8x4x2"; - mt.vec_dot = vec_dot_q8x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matvec_id(&mt, octx, n, i); -} - -static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, 
unsigned int i, void * data) { - struct htp_ops_context * octx = data; - - struct htp_matmul_type mt; - mt.type = "mxfp4x4x2-q8x4x2"; - mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; - mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - - matmul_id(&mt, octx, n, i); -} - -// ** main matmul entry point static inline bool htp_is_permuted(const struct htp_tensor * t) { return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; } +static int htp_mminit_vec_dot(struct htp_matmul_context * mmctx, enum htp_data_type type) { + switch (type) { + case HTP_TYPE_Q4_0: + mmctx->type = "q4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q4x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_Q8_0: + mmctx->type = "q8x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_q8x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_q8x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_q8x4x2_q8x4x2_2x2; + return 0; + case HTP_TYPE_MXFP4: + mmctx->type = "mxfp4x4x2-f32"; + mmctx->vec_dot_1x1 = vec_dot_mxfp4x4x2_q8x4x2_1x1; + mmctx->vec_dot_2x1 = vec_dot_mxfp4x4x2_q8x4x2_2x1; + mmctx->vec_dot_2x2 = vec_dot_mxfp4x4x2_q8x4x2_2x2; + return 0; + default: + return -1; + } +} + +static void htp_mminit_spad(struct htp_ops_context * octx, + size_t dst_row_size, + size_t src0_row_size_padded, + size_t src1_row_size, + uint32_t src1_nrows, + size_t src2_spad_size_per_thread) { + octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + + if (src2_spad_size_per_thread > 0) { + octx->src2_spad.size_per_thread = src2_spad_size_per_thread; + octx->src2_spad.size = octx->src2_spad.size_per_thread; + } + + // src0 spad is also used in dynamic quantizer to store padded src1 rows + size_t src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); + if (octx->src0_spad.size_per_thread < src1_row_size_padded) { + octx->src0_spad.size_per_thread = src1_row_size_padded; + } + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; +} + int op_matmul(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; - const char * op_type; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src1_nrows = ne11 * ne12 * ne13; + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; size_t src1_row_size = nb11; @@ -2085,181 +2447,95 @@ int op_matmul(struct htp_ops_context * octx) { size_t src1_row_size_padded; worker_callback_t quant_job_func; - worker_callback_t matmul_job_func; + worker_callback_t matmul_job_func = src1_nrows > 1 ? 
matmul_2d : matvec_2d; bool need_quant = !(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; - } + if (src0->type == HTP_TYPE_F16) { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); + + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + quant_job_func = (src1->type == HTP_TYPE_F32) ? quantize_f32_f16 : quantize_f16_f16; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_aa_1x1; + mmctx->vec_dot_2x1 = vec_dot_f16_f16_aa_2x1; + mmctx->vec_dot_2x2 = vec_dot_f16_f16_aa_2x2; + + src1_row_size = f16_src1_row_size; // row size post quantization octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + mmctx->type = "f16-f32"; + mmctx->vec_dot_1x1 = vec_dot_f16_f32_uu_1x1; + matmul_job_func = matmul_4d; } else { - matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; + mmctx->type = "f16-f16"; + mmctx->vec_dot_1x1 = vec_dot_f16_f16_uu_1x1; + matmul_job_func = matmul_4d; } - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size + src1_row_size = nb11; // original row size in DDR octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - 
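// Sketch of the VTCM budget check that gates the optimized f16-f16 path above:
// the entire converted src1 tensor plus the per-thread src0 and dst scratchpads
// must fit in the reserved VTCM, otherwise op_matmul falls back to the
// DDR-based matmul_4d path. round_up_256() stands in for hex_round_up(x, 256);
// parameter names are illustrative.
#include <stddef.h>
#include <stdint.h>

static inline size_t round_up_256(size_t x) { return (x + 255) & ~(size_t) 255; }

static int f16_path_fits_vtcm_sketch(size_t src1_row_size, uint32_t src1_nrows,
                                     size_t src0_row_size_padded, size_t dst_row_size,
                                     uint32_t n_threads, uint32_t spad_src0_nrows,
                                     uint32_t spad_dst_nrows, size_t vtcm_size) {
    const size_t src1_spad = round_up_256(src1_row_size * src1_nrows);                         // whole src1 in VTCM
    const size_t src0_spad = round_up_256(spad_src0_nrows * src0_row_size_padded) * n_threads; // N rows per thread
    const size_t dst_spad  = round_up_256(spad_dst_nrows  * dst_row_size)         * n_threads; // N rows per thread
    return src1_spad + src0_spad + dst_spad <= vtcm_size;
}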
octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); + octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - case HTP_TYPE_MXFP4: - op_type = "mxfp4x4x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; - } else { - matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; - } + // Init fastdiv for matmul_4d (supports broadcasting) + mmctx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + mmctx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + mmctx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + mmctx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_F16: - { - // Try optimized f16-f16 path first (src1 in VTCM) - const size_t f16_src1_row_size = hex_round_up(ne10 * 2, 128); - const size_t f16_src1_spad_size = hex_round_up(f16_src1_row_size * src1_nrows, 256); - const size_t f16_src0_spad_size = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; - const size_t f16_dst_spad_size = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - - const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - - // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). - // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. - const bool is_batched = (ne02 > 1) || (ne03 > 1); - const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - - if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { - // Optimized path - op_type = "f16-f16"; - quant_job_func = (src1->type == HTP_TYPE_F32) ? 
htp_quantize_f32_f16 : htp_quantize_f16_f16; - if (src1_nrows > 1) { - matmul_job_func = htp_matmul_2d_f16_f16; - } else { - matmul_job_func = htp_matvec_2d_f16_f16; - } - - src1_row_size = f16_src1_row_size; // row size post quantization - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - } else { - // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required - quant_job_func = NULL; - if (src1->type == HTP_TYPE_F32) { - op_type = "f16-f32"; - matmul_job_func = htp_matmul_4d_f16_f32; - } else { - op_type = "f16-f16"; - matmul_job_func = htp_matmul_4d_f16_f16; - } - - src1_row_size = nb11; // original row size in DDR - - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = hex_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); - - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - - // Init fastdiv for matmul_4d (supports broadcasting) - octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); - octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); - octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); - octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); - - need_quant = false; - } - } - break; - - default: + need_quant = false; + } + } else { + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { return HTP_STATUS_NO_SUPPORT; + } + + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, 0); } // VTCM scratchpads for all tensors size_t spad_size = octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-%s : src0-spad-size %u src1-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, src0->ne[0], + FARF(HIGH, "matmul-%s : %ux%ux%ux%u * %ux%ux%ux%u-> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, + FARF(ERROR, "matmul-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2268,40 +2544,32 @@ int op_matmul(struct htp_ops_context * octx) { octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; 
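// The refactor above threads one per-op context through every worker callback
// via the void * data argument. The shape of struct htp_matmul_context implied
// by its uses in this patch is sketched below; the real definition lives in the
// htp headers and may differ in details (restrict qualifiers, fastdiv type).
#include <stdint.h>

struct htp_ops_context; // defined elsewhere in the htp backend

struct htp_matmul_context_sketch {
    struct htp_ops_context * octx;  // back-pointer to the op context
    const char *             type;  // e.g. "q4x4x2-f32", used only for logging

    // 1x1, 2x1 (two src0 rows) and 2x2 (two rows x two columns) dot kernels
    void (*vec_dot_1x1)(int n, float * s,  const void * vx, const void * vy);
    void (*vec_dot_2x1)(int n, float * s0, const void * vx0, const void * vx1, const void * vy0);
    void (*vec_dot_2x2)(int n, float * s0, float * s1,
                        const void * vx0, const void * vx1,
                        const void * vy0, const void * vy1);

    uint32_t src0_nrows_per_thread; // rounded up to an even count: src0 rows are processed in pairs
    uint32_t src1_nrows_per_thread; // set only when the dynamic quantizer runs

    // plus the fastdiv constants mm_div_ne12_ne1, mm_div_ne1, mm_div_r2, mm_div_r3
    // used by matmul_4d to split a flat ir1 into (i13, i12, i11) and derive the
    // broadcast src0 indices i03 = i13 / r3, i02 = i12 / r2 without hardware division
};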
octx->dst_spad.data = octx->src1_spad.data + octx->src1_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even - octx->src0_spad.stride = src0_row_size_padded; octx->src1_spad.stride = src1_row_size; if (need_quant) { - // Run quant jobs - const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; } -// ** main matmul-id entry point - int op_matmul_id(struct htp_ops_context * octx) { htp_matmul_tensors_preamble; + struct htp_matmul_context mmctx_struct = {0}; + struct htp_matmul_context * mmctx = &mmctx_struct; + mmctx->octx = octx; + struct htp_tensor * restrict ids = &octx->src2; - const char * op_type; - - worker_callback_t quant_job_func; - worker_callback_t matmul_id_job_func; - const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; @@ -2310,6 +2578,13 @@ int op_matmul_id(struct htp_ops_context * octx) { const uint32_t src0_nrows = ne01; // per expert const uint32_t src1_nrows = ne11 * ne12 * ne13; + worker_callback_t quant_job_func; + worker_callback_t matmul_id_job_func = src1_nrows > 1 ? 
matmul_id : matvec_id; + + // Compute src0_nrows_per_thread + mmctx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; + mmctx->src0_nrows_per_thread += (mmctx->src0_nrows_per_thread & 1); // round up to even + size_t src1_row_size; size_t src1_row_size_padded; @@ -2320,112 +2595,29 @@ int op_matmul_id(struct htp_ops_context * octx) { size_t matrix_row_counts_size = n_as * sizeof(uint32_t); size_t matrix_row_map_size = n_as * ids->ne[0] * ids->ne[1] * sizeof(struct mmid_row_mapping); - switch (src0->type) { - case HTP_TYPE_Q4_0: - op_type = "q4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q4x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_Q8_0: - op_type = "q8x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_q8x4x2_q8x4x2; - } else { - matmul_id_job_func = htp_matvec_id_q8x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - case HTP_TYPE_MXFP4: - op_type = "mxfp4x2x2-f32"; - quant_job_func = htp_quantize_f32_q8x4x2; - src1_row_size = q8x4x2_row_size(ne10); // row size post quantization - if (src1_nrows > 1) { - matmul_id_job_func = htp_matmul_id_mxfp4x4x2_q8x4x2; - } else { - 
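// Sketch of the expert-row grouping that matmul_id relies on (the iid1/id loop
// further below): for every token and every selected expert, the (expert-slot,
// token) pair is appended to that expert's bucket, so all src1 rows hitting the
// same src0 expert matrix are processed together. Field names follow the
// row_mapping.i1 / row_mapping.i2 usage in this file; the real struct layout
// and the per-expert capacity may differ.
#include <stdint.h>
#include <stddef.h>

struct mmid_row_mapping_sketch { int32_t i1; int32_t i2; }; // i1 = expert slot, i2 = token idx

static void group_rows_by_expert_sketch(const uint8_t * ids_data, uint32_t ids_ne1, // tokens
                                        size_t ids_nb0, size_t ids_nb1,             // id/token strides
                                        uint32_t n_ids,                              // experts per token
                                        uint32_t per_expert_capacity,
                                        uint32_t * matrix_row_counts,                // [n_as], zeroed
                                        struct mmid_row_mapping_sketch * row_map) {  // [n_as * per_expert_capacity]
    for (uint32_t iid1 = 0; iid1 < ids_ne1; ++iid1) {    // token idx
        for (uint32_t id = 0; id < n_ids; ++id) {        // expert slot within the token
            const uint32_t i02 = *(const uint32_t *) (ids_data + iid1 * ids_nb1 + id * ids_nb0);
            row_map[i02 * per_expert_capacity + matrix_row_counts[i02]++] =
                (struct mmid_row_mapping_sketch) { (int32_t) id, (int32_t) iid1 };
        }
    }
}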
matmul_id_job_func = htp_matvec_id_mxfp4x4x2_q8x4x2; - } - - // Entire src1 tensor is placed into the VTCM - // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = hex_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = hex_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); - octx->src1_spad.size_per_thread = hex_round_up(src1_row_size * src1_nrows, 256); - octx->src2_spad.size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); - - // src0 spad is also used in dynamic quantizer to store padded src1 rows - src1_row_size_padded = hex_round_up(src1_row_size, QK_Q8_0x4x2 * sizeof(float)); - if (octx->src0_spad.size_per_thread < src1_row_size_padded) { - octx->src0_spad.size_per_thread = src1_row_size_padded; - } - - octx->src2_spad.size = octx->src2_spad.size_per_thread; - octx->src1_spad.size = octx->src1_spad.size_per_thread; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; - break; - - default: - return HTP_STATUS_NO_SUPPORT; + if (htp_mminit_vec_dot(mmctx, src0->type) != 0) { + return HTP_STATUS_NO_SUPPORT; } + quant_job_func = quantize_f32_q8x4x2; + src1_row_size = q8x4x2_row_size(ne10); + + const size_t src2_spad_size_per_thread = hex_round_up(matrix_row_counts_size + matrix_row_map_size, 256); + htp_mminit_spad(octx, dst_row_size, src0_row_size_padded, src1_row_size, src1_nrows, src2_spad_size_per_thread); + size_t spad_size = octx->src2_spad.size + octx->src1_spad.size + octx->src0_spad.size + octx->dst_spad.size; - FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", op_type, + FARF(HIGH, "matmul-id-%s : src0-spad-size %u src1-spad-size %u src2-spad-size %u dst-spad-size %u (%zu)\n", mmctx->type, octx->src0_spad.size, octx->src1_spad.size, octx->src2_spad.size, octx->dst_spad.size, spad_size); - FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", op_type, + FARF(HIGH, "matmul-id-%s : %ux%ux%ux%u * %ux%ux%ux%u (%ux%ux%ux%u) -> %ux%ux%ux%u (0x%p, 0x%p, 0x%p)\n", mmctx->type, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], ids->ne[0], ids->ne[1], ids->ne[2], ids->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], src0->data, src1->data, dst->data); // Make sure the reserved vtcm size is sufficient if (octx->ctx->vtcm_size < spad_size) { - FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", op_type, - octx->ctx->vtcm_size, spad_size); + FARF(ERROR, "matmul-id-%s : current VTCM reservation %zu is too small, needed %zu\n", mmctx->type, octx->ctx->vtcm_size, spad_size); return HTP_STATUS_VTCM_TOO_SMALL; } @@ -2434,8 +2626,8 @@ int op_matmul_id(struct htp_ops_context * octx) { octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; octx->dst_spad.data = octx->src2_spad.data + octx->src2_spad.size; - octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; - octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; if (src1_nrows > 1) { // initialize matrix_row_counts and map @@ -2447,8 +2639,7 @@ int op_matmul_id(struct htp_ops_context * octx) { // group rows by src0 matrix for (uint32_t iid1 = 0; iid1 < ids->ne[1]; 
++iid1) { // token idx for (uint32_t id = 0; id < n_ids; ++id) { // expert idx - const uint32_t i02 = - *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); + const uint32_t i02 = *(const uint32_t *) ((const uint8_t *) ids->data + iid1 * ids->nb[1] + id * ids->nb[0]); assert(i02 >= 0 && i02 < n_as); @@ -2460,16 +2651,14 @@ int op_matmul_id(struct htp_ops_context * octx) { // Setup worker pool callbacks if (!(octx->flags & HTP_OPFLAGS_SKIP_QUANTIZE)) { - // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); - octx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; - worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, octx, n_quant_jobs); + mmctx->src1_nrows_per_thread = (src1_nrows + n_quant_jobs - 1) / n_quant_jobs; + worker_pool_run_func(octx->ctx->worker_pool, quant_job_func, mmctx, n_quant_jobs); } if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { - // Run matmul-id jobs const uint32_t n_matmul_jobs = octx->n_threads; - worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, octx, n_matmul_jobs); + worker_pool_run_func(octx->ctx->worker_pool, matmul_id_job_func, mmctx, n_matmul_jobs); } return HTP_STATUS_OK; diff --git a/ggml/src/ggml-metal/ggml-metal-common.cpp b/ggml/src/ggml-metal/ggml-metal-common.cpp index 95627d3866..2eb9820bff 100644 --- a/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -264,15 +264,26 @@ static std::vector ggml_metal_graph_optimize_reorder(const std::vector ggml_metal_graph_optimize_reorder(const std::vectorsrc[0]->nb[0] == ggml_type_size(op->src[0]->type)); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); char base[256]; char name[256]; - const char * op_str = "undefined"; + int op_num = -1; + switch (op->op) { - case GGML_OP_SUM_ROWS: - op_str = "sum_rows"; break; - case GGML_OP_MEAN: - op_str = "mean"; break; + case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break; + case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break; default: GGML_ABORT("fatal error"); }; - snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type)); + const char * t0_str = ggml_type_name(op->src[0]->type); + const char * t_str = ggml_type_name(op->type); - snprintf(name, 256, "%s", base); + const bool is_c4 = op->src[0]->ne[0] % 4 == 0; + + snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? 
"_4" : ""); + snprintf(name, 256, "%s_op=%d", base, op_num); ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name); if (!res.pipeline) { - res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); } res.smem = 32*sizeof(float); + if (is_c4) { + res.smem *= 4; + } + + res.c4 = is_c4; + return res; } diff --git a/ggml/src/ggml-metal/ggml-metal-device.m b/ggml/src/ggml-metal/ggml-metal-device.m index c714ef3add..3db7f12629 100644 --- a/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ggml/src/ggml-metal/ggml-metal-device.m @@ -1019,7 +1019,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_SIN: case GGML_OP_COS: case GGML_OP_LOG: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { case GGML_UNARY_OP_TANH: @@ -1039,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_SOFTPLUS: case GGML_UNARY_OP_EXPM1: - return ggml_is_contiguous_rows(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16); default: return false; } @@ -1067,8 +1067,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL: case GGML_OP_DIV: case GGML_OP_ADD_ID: - return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: + return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32; case GGML_OP_REPEAT: case GGML_OP_CONV_TRANSPOSE_1D: return true; @@ -1159,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_MUL_MAT: case GGML_OP_MUL_MAT_ID: return has_simdgroup_reduction; + case GGML_OP_SET: case GGML_OP_CPY: case GGML_OP_DUP: case GGML_OP_CONT: diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h index 952e1be076..383e0d6e93 100644 --- a/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ggml/src/ggml-metal/ggml-metal-impl.h @@ -82,6 +82,7 @@ #define FC_COUNT_EQUAL 1100 #define FC_UNARY 1200 #define FC_BIN 1300 +#define FC_SUM_ROWS 1400 // op-specific constants #define OP_FLASH_ATTN_EXT_NQPSG 8 @@ -118,6 +119,8 @@ #define OP_UNARY_NUM_SOFTPLUS 115 #define OP_UNARY_NUM_EXPM1 116 +#define OP_SUM_ROWS_NUM_SUM_ROWS 10 +#define OP_SUM_ROWS_NUM_MEAN 11 // kernel argument structs // diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index 7db95d1c84..3d5db0b79f 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -426,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx); } break; + case GGML_OP_SET: + { + n_fuse = ggml_metal_op_set(ctx, idx); + } break; case GGML_OP_DUP: case GGML_OP_CPY: case GGML_OP_CONT: @@ -616,8 +620,8 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32); 
GGML_ASSERT(op->type == GGML_TYPE_F32); - GGML_ASSERT(ggml_is_contiguous(op->src[0])); - GGML_ASSERT(ggml_is_contiguous(op->src[1])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[1])); const size_t pnb1 = ((const int32_t *) op->op_params)[0]; const size_t pnb2 = ((const int32_t *) op->op_params)[1]; @@ -667,10 +671,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { } ggml_metal_kargs_bin args = { - /*.ne00 =*/ ne00, - /*.ne01 =*/ ne01, - /*.ne02 =*/ ne02, - /*.ne03 =*/ ne03, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, /*.nb00 =*/ nb00, /*.nb01 =*/ pnb1, /*.nb02 =*/ pnb2, @@ -683,10 +687,10 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, /*.nb13 =*/ nb13, - /*.ne0 =*/ ne0, - /*.ne1 =*/ ne1, - /*.ne2 =*/ ne2, - /*.ne3 =*/ ne3, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, /*.nb0 =*/ nb0, /*.nb1 =*/ pnb1, /*.nb2 =*/ pnb2, @@ -703,7 +707,13 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + int nth = 1; + + while (2*nth < args.ne0 && nth < nth_max) { + nth *= 2; + } ggml_metal_encoder_dispatch_threadgroups(enc, ne11, ne12, ne13, nth, 1, 1); @@ -904,6 +914,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + GGML_ASSERT(ggml_is_contiguous_rows(op->src[0])); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + ggml_metal_kargs_sum_rows args = { /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, @@ -925,21 +940,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op); + if (pipeline.c4) { + args.ne00 = ne00/4; + args.ne0 = ne0/4; + } + int nth = 32; // SIMD width - while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nth *= 2; } nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); - nth = std::min(nth, ne00); + nth = std::min(nth, (int) args.ne00); const size_t smem = pipeline.smem; ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -1599,6 +1619,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, 
ne, op, ne); + GGML_TENSOR_LOCALS(uint64_t, nb, op, nb); + + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + const size_t pnb1 = ((const int32_t *) op->op_params)[0]; + const size_t pnb2 = ((const int32_t *) op->op_params)[1]; + const size_t pnb3 = ((const int32_t *) op->op_params)[2]; + const size_t offs = ((const int32_t *) op->op_params)[3]; + + const bool inplace = (bool) ((const int32_t *) op->op_params)[4]; + + if (!inplace) { + // run a separete kernel to cpy src->dst + // not sure how to avoid this + // TODO: make a simpler cpy_bytes kernel + + //const id pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj; + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, + /*.ne00 =*/ ne00, + /*.ne01 =*/ ne01, + /*.ne02 =*/ ne02, + /*.ne03 =*/ ne03, + /*.nb00 =*/ nb00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne0 =*/ ne0, + /*.ne1 =*/ ne1, + /*.ne2 =*/ ne2, + /*.ne3 =*/ ne3, + /*.nb0 =*/ nb0, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, + }; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00); + + ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1); + + ggml_metal_op_concurrency_reset(ctx); + } + + auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type); + + GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0); + + int64_t nk0 = ne10; + if (ggml_is_quantized(op->src[1]->type)) { + nk0 = ne10/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne10/ggml_blck_size(op->type); + } + + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + // when rows are small, we can batch them together in a single threadgroup + int nrptg = 1; + + // TODO: relax this constraint in the future + if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) { + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; + + if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { + nrptg--; + } + } + } + + nth = std::min(nth, nk0); + + ggml_metal_kargs_cpy args = { + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne10, + /*.ne01 =*/ ne11, + /*.ne02 =*/ ne12, + /*.ne03 =*/ ne13, + /*.nb00 =*/ nb10, + /*.nb01 =*/ nb11, + /*.nb02 =*/ nb12, + /*.nb03 =*/ nb13, + /*.ne0 =*/ ne10, + /*.ne1 =*/ ne11, + /*.ne2 =*/ ne12, + /*.ne3 =*/ ne13, + /*.nb0 =*/ ggml_element_size(op), + /*.nb1 =*/ pnb1, + /*.nb2 =*/ pnb2, + /*.nb3 =*/ pnb3, + }; + + const int nw0 = nrptg == 1 ? 
(nk0 + nth - 1)/nth : 1; + + bid_dst.offs += offs; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_dst, 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1); + + return 1; +} + int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.h b/ggml/src/ggml-metal/ggml-metal-ops.h index 29456d70d5..f3e38c7aa9 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ggml/src/ggml-metal/ggml-metal-ops.h @@ -59,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx); int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx); int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_set (ggml_metal_op_t ctx, int idx); int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx); int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx); diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index a385a50b94..6c349aa0c9 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -77,6 +77,14 @@ static inline float dot(float x, float y) { return x*y; } +static inline float sum(float x) { + return x; +} + +static inline float sum(float4 x) { + return x[0] + x[1] + x[2] + x[3]; +} + // NOTE: this is not dequantizing - we are simply fitting the template template void dequantize_f32(device const float4x4 * src, short il, thread type4x4 & reg) { @@ -910,7 +918,7 @@ constant float a4_erf = -1.453152027f; constant float a5_erf = 1.061405429f; template -T erf_approx(T x) { +inline T erf_approx(T x) { T sign_x = sign(x); x = fabs(x); T t = 1.0f / (1.0f + p_erf * x); @@ -918,10 +926,27 @@ T erf_approx(T x) { return sign_x * y; } +template T elu_approx(T x); + +template<> inline float elu_approx(float x) { + return (x > 0.f) ? x : (exp(x) - 1); +} + +template<> inline float4 elu_approx(float4 x) { + float4 res; + + res[0] = (x[0] > 0.0f) ? x[0] : (exp(x[0]) - 1.0f); + res[1] = (x[1] > 0.0f) ? x[1] : (exp(x[1]) - 1.0f); + res[2] = (x[2] > 0.0f) ? x[2] : (exp(x[2]) - 1.0f); + res[3] = (x[3] > 0.0f) ? 
x[3] : (exp(x[3]) - 1.0f); + + return res; +} + constant short FC_unary_op [[function_constant(FC_UNARY + 0)]]; constant bool FC_unary_cnt[[function_constant(FC_UNARY + 1)]]; -template +template kernel void kernel_unary_impl( constant ggml_metal_kargs_unary & args, device const char * src0, @@ -963,111 +988,111 @@ kernel void kernel_unary_impl( } } - device const T0 & x = src0_ptr[i0]; + const TC x = (TC) src0_ptr[i0]; if (FC_OP == OP_UNARY_NUM_SCALE) { - dst_ptr[i0] = args.scale * x + args.bias; + dst_ptr[i0] = (T) (args.scale * x + args.bias); } if (FC_OP == OP_UNARY_NUM_FILL) { - dst_ptr[i0] = args.val; + dst_ptr[i0] = (T) args.val; } if (FC_OP == OP_UNARY_NUM_CLAMP) { - dst_ptr[i0] = clamp(x, args.min, args.max); + dst_ptr[i0] = (T) clamp(x, args.min, args.max); } if (FC_OP == OP_UNARY_NUM_SQR) { - dst_ptr[i0] = x * x; + dst_ptr[i0] = (T) (x * x); } if (FC_OP == OP_UNARY_NUM_SQRT) { - dst_ptr[i0] = sqrt(x); + dst_ptr[i0] = (T) sqrt(x); } if (FC_OP == OP_UNARY_NUM_SIN) { - dst_ptr[i0] = sin(x); + dst_ptr[i0] = (T) sin(x); } if (FC_OP == OP_UNARY_NUM_COS) { - dst_ptr[i0] = cos(x); + dst_ptr[i0] = (T) cos(x); } if (FC_OP == OP_UNARY_NUM_LOG) { - dst_ptr[i0] = log(x); + dst_ptr[i0] = (T) log(x); } if (FC_OP == OP_UNARY_NUM_LEAKY_RELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(x * args.slope); + dst_ptr[i0] = (T) (TC(x > 0)*x + TC(x <= 0)*(x * args.slope)); } if (FC_OP == OP_UNARY_NUM_TANH) { - dst_ptr[i0] = precise::tanh(x); + dst_ptr[i0] = (T) precise::tanh(x); } if (FC_OP == OP_UNARY_NUM_RELU) { - dst_ptr[i0] = fmax(0.0f, x); + dst_ptr[i0] = (T) fmax(0, x); } if (FC_OP == OP_UNARY_NUM_SIGMOID) { - dst_ptr[i0] = 1.0f / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (1 / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_GELU) { - dst_ptr[i0] = 0.5f*x*(1.0f + precise::tanh(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x))); + dst_ptr[i0] = (T) (0.5*x*(1 + precise::tanh(SQRT_2_OVER_PI*x*(1 + GELU_COEF_A*x*x)))); } if (FC_OP == OP_UNARY_NUM_GELU_ERF) { - dst_ptr[i0] = 0.5f*x*(1.0f + erf_approx(SQRT_2_INV*x)); + dst_ptr[i0] = (T) (0.5*x*(1 + erf_approx(SQRT_2_INV*x))); } if (FC_OP == OP_UNARY_NUM_GELU_QUICK) { - dst_ptr[i0] = x * (1.0f/(1.0f + exp(GELU_QUICK_COEF*x))); + dst_ptr[i0] = (T) (x * (1/(1 + exp(GELU_QUICK_COEF*x)))); } if (FC_OP == OP_UNARY_NUM_SILU) { - dst_ptr[i0] = x / (1.0f + exp(-x)); + dst_ptr[i0] = (T) (x / (1 + exp(-x))); } if (FC_OP == OP_UNARY_NUM_ELU) { - dst_ptr[i0] = T(x > 0.0f)*x + T(x <= 0.0f)*(exp(x) - 1.0f); + dst_ptr[i0] = (T) elu_approx(x); } if (FC_OP == OP_UNARY_NUM_NEG) { - dst_ptr[i0] = -x; + dst_ptr[i0] = (T) -x; } if (FC_OP == OP_UNARY_NUM_ABS) { - dst_ptr[i0] = fabs(x); + dst_ptr[i0] = (T) fabs(x); } if (FC_OP == OP_UNARY_NUM_SGN) { - dst_ptr[i0] = T(x > 0.0f) - T(x < 0.0f); + dst_ptr[i0] = T(x > 0) - T(x < 0); } if (FC_OP == OP_UNARY_NUM_STEP) { - dst_ptr[i0] = T(x > 0.0f); + dst_ptr[i0] = T(x > 0); } if (FC_OP == OP_UNARY_NUM_HARDSWISH) { - dst_ptr[i0] = x * fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) (x * fmax(0, fmin(1, x/6 + 0.5))); } if (FC_OP == OP_UNARY_NUM_HARDSIGMOID) { - dst_ptr[i0] = fmax(0.0f, fmin(1.0f, x/6.0f + 0.5f)); + dst_ptr[i0] = (T) fmax(0, fmin(1, x/6 + 0.5)); } if (FC_OP == OP_UNARY_NUM_EXP) { - dst_ptr[i0] = exp(x); + dst_ptr[i0] = (T) exp(x); } if (FC_OP == OP_UNARY_NUM_SOFTPLUS) { - dst_ptr[i0] = select(log(1.0f + exp(x)), x, x > 20.0f); + dst_ptr[i0] = (T) select(log(1 + exp(x)), x, x > 20); } if (FC_OP == OP_UNARY_NUM_EXPM1) { // TODO: precise implementation - dst_ptr[i0] = exp(x) - 1.0f; + dst_ptr[i0] = (T) (exp(x) - 1); } } @@ 
-1075,11 +1100,12 @@ kernel void kernel_unary_impl( #undef FC_CNT } -typedef decltype(kernel_unary_impl) kernel_unary_t; - -template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; -template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +typedef decltype(kernel_unary_impl) kernel_unary_t; +template [[host_name("kernel_unary_f32_f32")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f32_f32_4")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16")]] kernel kernel_unary_t kernel_unary_impl; +template [[host_name("kernel_unary_f16_f16_4")]] kernel kernel_unary_t kernel_unary_impl; // OP: 0 - add, 1 - sub, 2 - mul, 3 - div constant short FC_bin_op [[function_constant(FC_BIN + 0)]]; @@ -1483,33 +1509,35 @@ kernel void kernel_op_sum_f32( } } -template -kernel void kernel_sum_rows( +constant short FC_sum_rows_op [[function_constant(FC_SUM_ROWS + 0)]]; + +template +kernel void kernel_sum_rows_impl( constant ggml_metal_kargs_sum_rows & args, - device const float * src0, - device float * dst, - threadgroup float * shmem_f32 [[threadgroup(0)]], + device const char * src0, + device char * dst, + threadgroup char * shmem [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort3 tpitg[[thread_position_in_threadgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]], ushort tiisg[[thread_index_in_simdgroup]], ushort3 ntg[[threads_per_threadgroup]]) { - int64_t i3 = tgpig.z; - int64_t i2 = tgpig.y; - int64_t i1 = tgpig.x; +#define FC_OP FC_sum_rows_op - if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) { - return; - } + const int i3 = tgpig.z; + const int i2 = tgpig.y; + const int i1 = tgpig.x; + + threadgroup T0 * shmem_t = (threadgroup T0 *) shmem; if (sgitg == 0) { - shmem_f32[tiisg] = 0.0f; + shmem_t[tiisg] = 0.0f; } - device const float * src_row = (device const float *) ((device const char *) src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); - device float * dst_row = (device float *) ((device char *) dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); + device const T0 * src_row = (device const T0 *) (src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03); + device T * dst_row = (device T *) (dst + i1*args.nb1 + i2*args.nb2 + i3*args.nb3); - float sumf = 0; + T0 sumf = T0(0.0f); for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) { sumf += src_row[i0]; @@ -1520,23 +1548,33 @@ kernel void kernel_sum_rows( threadgroup_barrier(mem_flags::mem_threadgroup); if (tiisg == 0) { - shmem_f32[sgitg] = sumf; + shmem_t[sgitg] = sumf; } threadgroup_barrier(mem_flags::mem_threadgroup); - sumf = shmem_f32[tiisg]; + sumf = shmem_t[tiisg]; sumf = simd_sum(sumf); if (tpitg.x == 0) { - dst_row[0] = norm ? 
sumf / args.ne00 : sumf; + if (FC_OP == OP_SUM_ROWS_NUM_MEAN) { + if (is_same::value) { + dst_row[0] = sum(sumf) / (4*args.ne00); + } else { + dst_row[0] = sum(sumf) / args.ne00; + } + } else { + dst_row[0] = sum(sumf); + } } + +#undef FC_OP } -typedef decltype(kernel_sum_rows) kernel_sum_rows_t; +typedef decltype(kernel_sum_rows_impl) kernel_sum_rows_t; -template [[host_name("kernel_sum_rows_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; -template [[host_name("kernel_mean_f32")]] kernel kernel_sum_rows_t kernel_sum_rows; +template [[host_name("kernel_sum_rows_f32_f32")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; +template [[host_name("kernel_sum_rows_f32_f32_4")]] kernel kernel_sum_rows_t kernel_sum_rows_impl; template kernel void kernel_cumsum_blk( @@ -2417,9 +2455,6 @@ kernel void kernel_solve_tri_f32( const short K = FC_solve_tri_k; const short NP = PAD2(N, NW); - const int32_t ne02 = args.ne02; - const int32_t ne03 = args.ne03; - const int32_t i03 = tgpig.z; const int32_t i02 = tgpig.y; const int32_t i01 = tgpig.x*NSG + sgitg; @@ -5931,7 +5966,7 @@ kernel void kernel_flash_attn_ext_vec( static_assert(DK4 % NL == 0, "DK4 must be divisible by NL"); static_assert(DV4 % NL == 0, "DV4 must be divisible by NL"); - const short T = PK + NSG*SH; // shared memory size per query in (half) + //const short T = PK + NSG*SH; // shared memory size per query in (half) //threadgroup q_t * sq = (threadgroup q_t *) (shmem_f16 + 0*PK); // holds the query data threadgroup q4_t * sq4 = (threadgroup q4_t *) (shmem_f16 + 0*PK); // same as above but in q4_t @@ -8519,7 +8554,9 @@ kernel void kernel_mul_mm( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -8642,8 +8679,8 @@ kernel void kernel_mul_mm( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; @@ -8892,7 +8929,9 @@ kernel void kernel_mul_mm_id( threadgroup S0 * sa = (threadgroup S0 *)(shmem); threadgroup S1 * sb = (threadgroup S1 *)(shmem + 4096); +#ifdef GGML_METAL_HAS_TENSOR threadgroup float * sc = (threadgroup float *)(shmem); +#endif constexpr int NR0 = 64; constexpr int NR1 = 32; @@ -9027,8 +9066,8 @@ kernel void kernel_mul_mm_id( const short sx = (tiitg%NL1); const short sy = (tiitg/NL1)/8; - const short dx = sx; - const short dy = sy; + //const short dx = sx; + //const short dy = sy; const short ly = (tiitg/NL1)%8; diff --git a/ggml/src/ggml-opencl/CMakeLists.txt b/ggml/src/ggml-opencl/CMakeLists.txt index fa5fadd112..f389193691 100644 --- a/ggml/src/ggml-opencl/CMakeLists.txt +++ b/ggml/src/ggml-opencl/CMakeLists.txt @@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS mul_mv_q4_0_f32_8x_flat mul_mv_q4_0_f32_1d_8x_flat mul_mv_q4_0_f32_1d_16x_flat + mul_mv_q4_1_f32 + mul_mv_q4_1_f32_flat + mul_mv_q4_k_f32 mul_mv_q6_k_f32 mul_mv_q6_k_f32_flat mul_mv_q8_0_f32 @@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS gemv_moe_mxfp4_f32 mul_mm_f32_f32_l4_lm mul_mm_f16_f32_l4_lm + mul_mm_q4_0_f32_l4_lm + mul_mm_q4_1_f32_l4_lm mul_mm_q8_0_f32_l4_lm + mul_mm_q6_k_f32_l4_lm mul_mm_q8_0_f32_8x4 gemv_noshuffle_general_q8_0_f32 mul diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp index 508b2b8f03..ae3f79fd0d 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp @@ 
-525,6 +525,7 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mm_f16_f32_kq; cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v; cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0; + cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1; cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans; cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans; cl_kernel kernel_mul_mat_q4_0_f32_8x_flat; @@ -532,6 +533,9 @@ struct ggml_backend_opencl_context { cl_kernel kernel_restore_block_q4_0_noshuffle; cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K; cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat; + cl_kernel kernel_mul_mv_q4_1_f32; + cl_kernel kernel_mul_mv_q4_1_f32_flat; + cl_kernel kernel_mul_mv_q4_K_f32; cl_kernel kernel_mul_mv_q6_K_f32; cl_kernel kernel_mul_mv_q6_K_f32_flat; cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat; @@ -563,7 +567,10 @@ struct ggml_backend_opencl_context { cl_kernel kernel_mul_mv_id_mxfp4_f32_flat; cl_kernel kernel_mul_mm_f32_f32_l4_lm; cl_kernel kernel_mul_mm_f16_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q4_1_f32_l4_lm; cl_kernel kernel_mul_mm_q8_0_f32_l4_lm; + cl_kernel kernel_mul_mm_q6_k_f32_l4_lm; std::vector profiling_info; @@ -886,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err)); + CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err)); + CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err)); CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err)); CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err)); @@ -1117,6 +1126,57 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mv_q4_1_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_1_f32_flat + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_1_f32_flat.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl"); +#endif + cl_program prog = 
+ build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + + // mul_mv_q4_k_f32 + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mv_q4_k_f32.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mv_q6_k_f32 { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1342,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q4_0_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_0_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + + // mul_mm_q4_1_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q4_1_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err)); + GGML_LOG_CONT("."); + } + // mul_mm_q8_0_f32_l4_lm { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -1358,6 +1450,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve GGML_LOG_CONT("."); } + // mul_mm_q6_k_f32_l4_lm + { +#ifdef GGML_OPENCL_EMBED_KERNELS + const std::string kernel_src { + #include "mul_mm_q6_k_f32_l4_lm.cl.h" + }; +#else + const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl"); +#endif + cl_program prog = + build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts); + + CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err)); + CL_CHECK(clReleaseProgram(prog)); + GGML_LOG_CONT("."); + } + // mul_mm_f16_f32_kq_kqv { #ifdef GGML_OPENCL_EMBED_KERNELS @@ -2887,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 { } }; +struct ggml_tensor_extra_cl_q4_1 { + // Quantized values. + cl_mem q = nullptr; + // Quantized values in image1d_buffer_t. + cl_mem q_img = nullptr; + // Scales. + cl_mem d = nullptr; + // Scales in image1d_buffer_t. + cl_mem d_img = nullptr; + // Min + cl_mem m = nullptr; + // Min in image1d_buffer_t. + cl_mem m_img = nullptr; + // Size of quantized values. + size_t size_q = 0; + // Size of scales. + size_t size_d = 0; + // Size of min values. + size_t size_m = 0; + + ~ggml_tensor_extra_cl_q4_1() { + reset(); + } + + void reset() { + // q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer. 
+ // They must be properly released so that the original buffer can be + // properly released to avoid memory leak. + if (q != nullptr) { + CL_CHECK(clReleaseMemObject(q)); + q = nullptr; + } + if (d != nullptr) { + CL_CHECK(clReleaseMemObject(d)); + d = nullptr; + } + if (m != nullptr) { + CL_CHECK(clReleaseMemObject(m)); + m = nullptr; + } + // Currently, q_img and d_img are only initialized when SMALL_ALLOC is + // enabled. They point to the images in ggml_backend_opencl_buffer_context. + // So, there is no need to release them here. + // TODO: initialize them for non SMALL_PATH path, or remove them. + q_img = nullptr; + d_img = nullptr; + m_img = nullptr; + size_q = 0; + size_d = 0; + size_m = 0; + } +}; + struct ggml_tensor_extra_cl_mxfp4 { // Quantized values. cl_mem q = nullptr; @@ -3363,7 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te return true; } else if (op->src[0]->type == GGML_TYPE_F32) { return op->src[1]->type == GGML_TYPE_F32; - } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 || + } else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 || + op->src[0]->type == GGML_TYPE_MXFP4 || + op->src[0]->type == GGML_TYPE_Q4_K || op->src[0]->type == GGML_TYPE_Q6_K) { return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]); } else if (op->src[0]->type == GGML_TYPE_Q8_0) { @@ -3592,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context { return extra; } + ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() { + ggml_tensor_extra_cl_q4_1 * extra; + if (temp_tensor_extras_q4_1.empty()) { + extra = new ggml_tensor_extra_cl_q4_1(); + } else { + extra = temp_tensor_extras_q4_1.back(); + temp_tensor_extras_q4_1.pop_back(); + } + + temp_tensor_extras_q4_1_in_use.push_back(extra); + + extra->reset(); + return extra; + } + ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() { ggml_tensor_extra_cl_mxfp4 * extra; if (temp_tensor_extras_mxfp4.empty()) { @@ -3648,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context { } temp_tensor_extras_q4_0_in_use.clear(); + for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) { + temp_tensor_extras_q4_1.push_back(e); + } + temp_tensor_extras_q4_1_in_use.clear(); + for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) { temp_tensor_extras_mxfp4.push_back(e); } @@ -3673,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context { std::vector temp_tensor_extras_in_use; std::vector temp_tensor_extras_q4_0; std::vector temp_tensor_extras_q4_0_in_use; + std::vector temp_tensor_extras_q4_1; + std::vector temp_tensor_extras_q4_1_in_use; std::vector temp_tensor_extras_mxfp4; std::vector temp_tensor_extras_mxfp4_in_use; std::vector temp_tensor_extras_q8_0; @@ -4042,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer, return; } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; + GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); + + // Allocate the new extra and create aliases from the original. 
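For reference, the byte accounting behind the Q4_1 split that follows: with QK4_1 == 32, each block_q4_1 holds a 2-byte fp16 scale, a 2-byte fp16 min and 16 bytes of packed nibbles, so the scale, min and quant subbuffers together cover exactly the original AOS allocation. A minimal standalone sketch of that arithmetic (illustrative only, hypothetical tensor size, plain C++ stand-ins for the ggml types):

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main() {
        const size_t QK4_1      = 32;                 // values per block (ggml constant)
        const size_t n_elements = 4096 * QK4_1;       // hypothetical tensor element count
        const size_t n_blocks   = n_elements / QK4_1;

        const size_t size_d = n_blocks * sizeof(uint16_t);  // one fp16 scale per block
        const size_t size_m = n_blocks * sizeof(uint16_t);  // one fp16 min per block
        const size_t size_q = n_blocks * (QK4_1 / 2);       // 4-bit quants, two per byte

        const size_t aos_bytes = n_blocks * (2 + 2 + QK4_1 / 2); // sizeof(block_q4_1) == 20
        assert(size_d + size_m + size_q == aos_bytes);           // mirrors the GGML_ASSERT below
        return 0;
    }
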
+ ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; + ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1(); + + size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t); + size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2; + GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size"); + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + CL_CHECK(clEnqueueWriteBuffer( + queue, data_device, CL_TRUE, 0, + ggml_nbytes(tensor), data, 0, NULL, NULL)); + + cl_buffer_region region; + + // The original tensor memory is divided into scales and quants, i.e., + // we first store scales, mins, then quants. + // Create subbuffer for scales. + region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment); + region.size = size_d; + extra->d = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + auto previous_origin = region.origin; + + // Create subbuffer for mins. + region.origin = align_to(previous_origin + size_d, backend_ctx->alignment); + region.size = size_m; + extra->m = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + previous_origin = region.origin; + + // Create subbuffer for quants. + region.origin = align_to(previous_origin + size_m, backend_ctx->alignment); + region.size = size_q; + extra->q = clCreateSubBuffer( + extra_orig->data_device, CL_MEM_READ_WRITE, + CL_BUFFER_CREATE_TYPE_REGION, ®ion, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {64, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clReleaseMemObject(data_device)); + + tensor->extra = extra; + + return; + } if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra; GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized"); @@ -4544,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer, size, data, 0, NULL, NULL)); CL_CHECK(clReleaseMemObject(data_device)); return; - } else if (tensor->type == GGML_TYPE_MXFP4) { + } + if (tensor->type == GGML_TYPE_Q4_1) { + ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra; + + cl_int err; + cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE, + ggml_nbytes(tensor), NULL, &err); + CL_CHECK(err); + + cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1; + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d)); + CL_CHECK(clSetKernelArg(kernel, 2, 
sizeof(cl_mem), &extra->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device)); + + size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1}; + size_t local_work_size[] = {1, 1, 1}; + + cl_event evt; + CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, + global_work_size, local_work_size, 0, NULL, &evt)); + CL_CHECK(clWaitForEvents(1, &evt)); + CL_CHECK(clEnqueueReadBuffer( + queue, data_device, CL_TRUE, offset, + size, data, 0, NULL, NULL)); + CL_CHECK(clReleaseMemObject(data_device)); + return; + } + if (tensor->type == GGML_TYPE_MXFP4) { ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra; cl_int err; @@ -8372,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co #ifdef GGML_OPENCL_SOA_Q ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra; + ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra; ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra; ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra; ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra; @@ -8885,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q4_0: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
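The launch geometry for these *_l4_lm kernels follows directly from the tile constants: nth0 = (BM*BN)/(TM*TN) = (64*64)/(4*8) = 128 work-items per workgroup, each workgroup producing one 64x64 output tile, with CEIL_DIV(ne01, 64) tiles along dim 0, CEIL_DIV(ne11, 64) along dim 1 and ne12*ne13 batches along dim 2. A small sketch of that arithmetic (hypothetical shapes, plain C++):

    #include <cstdio>
    #include <cstddef>

    static size_t ceil_div(size_t a, size_t b) { return (a + b - 1) / b; }

    int main() {
        const size_t BM = 64, BN = 64, TM = 4, TN = 8;
        const size_t nth0 = (BM * BN) / (TM * TN);                // 128 threads per workgroup

        const size_t ne01 = 4096, ne11 = 100, ne12 = 2, ne13 = 1; // hypothetical shapes

        const size_t global[3] = { ceil_div(ne01, BM) * nth0, ceil_div(ne11, BN), ne12 * ne13 };
        const size_t local[3]  = { nth0, 1, 1 };

        // here: 64 tile rows, 2 tile columns, 2 batches
        std::printf("workgroups: %zu x %zu x %zu (local %zu)\n",
                    global[0] / local[0], global[1], global[2], local[0]);
        return 0;
    }
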
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } + case GGML_TYPE_Q4_1: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } case GGML_TYPE_Q8_0: { if (ne11 < 32) { break; @@ -8927,6 +9296,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); return; } + case GGML_TYPE_Q6_K: { + if (ne11 < 32) { + break; + } + if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) { + break; + } + + kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm; + nth0 = 128; // calculated as (BM*BN)/(TM*TN) + + int batch_stride_a = ne00*ne01; + int batch_stride_b = ne10*ne11; + int batch_stride_d = ne0*ne1; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d)); + CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3)); + + // 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed. 
+ size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, 1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); + return; + } default: break; } @@ -9181,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); #endif // GGML_OPENCL_SOA_Q break; - case GGML_TYPE_Q4_1: + case GGML_TYPE_Q4_1: { +#ifdef GGML_OPENCL_SOA_Q + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3)); +#else + if (backend_ctx->gpu_family == INTEL) { + nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + kernel = backend_ctx->kernel_mul_mv_q4_1_f32; + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3)); +#endif // GGML_OPENCL_SOA_Q + break; + } case GGML_TYPE_Q8_0: { #ifdef GGML_OPENCL_SOA_Q kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat; @@ -9262,7 +9739,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co } case GGML_TYPE_Q2_K: case GGML_TYPE_Q3_K: - case GGML_TYPE_Q4_K: + case GGML_TYPE_Q4_K: { + kernel = backend_ctx->kernel_mul_mv_q4_K_f32; + + if (backend_ctx->gpu_family == INTEL) { 
+ nth0 = 16; + nth1 = 1; + ndst = 4; + } else if (backend_ctx->gpu_family == ADRENO) { + nth0 = 64; + nth1 = 1; + ndst = 4; + } else { + GGML_ASSERT(false && "TODO: Unknown GPU"); + } + + CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device)); + CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0)); + CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device)); + CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1)); + CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device)); + CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd)); + CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00)); + CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01)); + CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01)); + CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02)); + CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03)); + CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12)); + CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11)); + CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12)); + CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13)); + CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0)); + CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1)); + CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2)); + CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3)); + break; + } case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: #ifdef GGML_OPENCL_SOA_Q @@ -9424,7 +9936,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q4_K) { - GGML_ASSERT(false && "not implemented"); + size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13}; + size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1}; + + backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst); } else if (src0t == GGML_TYPE_Q3_K) { GGML_ASSERT(false && "not implemented"); } else if (src0t == GGML_TYPE_Q5_K) { diff --git a/ggml/src/ggml-opencl/kernels/cvt.cl b/ggml/src/ggml-opencl/kernels/cvt.cl index 9fb434713d..2c244ce321 100644 --- a/ggml/src/ggml-opencl/kernels/cvt.cl +++ b/ggml/src/ggml-opencl/kernels/cvt.cl @@ -46,6 +46,15 @@ struct block_q4_0 uint8_t qs[QK4_0 / 2]; }; +//------------------------------------------------------------------------------ +// block_q4_1 +//------------------------------------------------------------------------------ +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + //------------------------------------------------------------------------------ // block_q6_K //------------------------------------------------------------------------------ @@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle( } } +//------------------------------------------------------------------------------ +// kernel_convert_block_q4_1 +// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA). +// This kernel does not deshuffle the bits. 
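Worth noting for the conversion and mul_mm kernels that follow: Q4_0 and Q4_1 share the packed-nibble layout but differ in the dequantization rule, Q4_0 reconstructing d*(q - 8) while Q4_1 reconstructs d*q + m with a per-block minimum. A scalar reference of both rules (illustrative values only, plain C++):

    #include <cstdio>

    int main() {
        const unsigned char packed = 0xA3;  // low nibble = 3, high nibble = 10
        const float d = 0.25f;              // block scale
        const float m = -1.0f;              // block min (Q4_1 only)

        const int lo = packed & 0x0F;
        const int hi = (packed >> 4) & 0x0F;

        // Q4_0: symmetric around zero
        const float q4_0[2] = { d * (lo - 8), d * (hi - 8) };
        // Q4_1: scale plus per-block min
        const float q4_1[2] = { d * lo + m, d * hi + m };

        std::printf("q4_0: %.3f %.3f   q4_1: %.3f %.3f\n", q4_0[0], q4_0[1], q4_1[0], q4_1[1]);
        return 0;
    }
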
+//------------------------------------------------------------------------------ +kernel void kernel_convert_block_q4_1( + global struct block_q4_1 * src0, + global uchar * dst_q, + global half * dst_d, + global half * dst_m +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0); + global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) dst_d + get_global_id(0); + global half * m = (global half *) dst_m + get_global_id(0); + + *d = b->d; + *m = b->m; + + for (int i = 0; i < QK4_1/2; ++i) { + q[i] = b->qs[i]; + } +} + +kernel void kernel_restore_block_q4_1( + global uchar * src_q, + global half * src_d, + global half * src_m, + global struct block_q4_1 * dst +) { + global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0); + global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0); + global half * d = (global half *) src_d + get_global_id(0); + global half * m = (global half *) src_m + get_global_id(0); + + b->d = *d; + b->m = *m; + for (int i = 0; i < QK4_1/2; ++i) { + b->qs[i] = q[i]; + } +} + //------------------------------------------------------------------------------ // block_mxfp4 //------------------------------------------------------------------------------ diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl new file mode 100644 index 0000000000..4100e3080a --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_0_f32_l4_lm.cl @@ -0,0 +1,163 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_0_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + 
loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl new file mode 100644 index 0000000000..d0d2f08361 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q4_1_f32_l4_lm.cl @@ -0,0 +1,165 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 8 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q4_1_f32_l4_lm( + global uchar4 * src0_q, + global half * src0_d, + global half * src0_m, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int 
stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + int ib = idx / 4; + int iqs = idx % 4; + + float d = (float)src0_d[ib]; + float m = (float)src0_m[ib]; + global uchar4 * qs = src0_q + ib*4 + iqs; + uchar4 q = *qs; + float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m; + float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m; + + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3; + } else { + buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + 
barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl new file mode 100644 index 0000000000..3602c92fef --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mm_q6_k_f32_l4_lm.cl @@ -0,0 +1,158 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#define LOAD_VEC_A 2 +#define LOAD_VEC_B 4 + +#define BM 64 +#define BN 64 +#define BK 32 +#define TM 4 +#define TN 8 + +kernel void kernel_mul_mm_q6_k_f32_l4_lm( + global uchar * src0_ql, + global uchar * src0_qh, + global char * src0_s, + global half * src0_d, + global float4 * src1, + ulong offset1, + global float * dst, + ulong offsetd, + + int ne00, + int ne01, + int ne02, + int ne11, + int ne12, + + int stride_a, + int stride_b, + int stride_d, + + int batch_stride_a, + int batch_stride_b, + int batch_stride_d, + + int r2, + int r3 +) { + src1 = (global float4*)((global char*)src1 + offset1); + dst = (global float *)((global char*)dst + offsetd); + + local float buf_a[BM * BK]; + local float buf_b[BN * BK]; + + const int batch_idx = get_global_id(2); + + const int i13 = batch_idx / ne12; + const int i12 = batch_idx % ne12; + + const int i03 = i13 / r3; + const int i02 = i12 / r2; + + const int batch_idx_a = i03 * ne02 + i02; + + const int ir = get_group_id(0); + const int ic = get_group_id(1); + + const int tid = get_local_id(0); + const int th_r = tid % (BM / TM); + const int th_c = tid / (BM / TM); + + const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A); + const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A); + const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B); + const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B); + + const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK; + const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK; + + int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A; + int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B; + + float sums[TM * TN]; + float cache_a[TM]; + float cache_b[TN]; + + for (int i = 0; i < TM * TN; i++) { + sums[i] = 0.0f; + } + + for (int block = 0; block < ne00; block += BK) { + for (int l = 0; l < BM; l += loadstride_a) { + if (ir*BM + loadc_a + l < ne01) { + int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a; + + int ib = idx / 128; // 2 values per idx + int iqs = idx % 128; // 0..127 + + int n = iqs / 64; // 0,1 + int b = (iqs % 64) / 32; // 0,1 + int is_b = (iqs % 16) / 8; // 0,1 + int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + int is = 8 * n + qhshift + is_b; // 0..15 + int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + float dscale = 
(float)src0_d[ib] * (float)src0_s[ib*16 + is]; + + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32); + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32); + } else { + buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f; + buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f; + } + } + + for (int l = 0; l < BN; l += loadstride_b) { + if (ic*BN + loadc_b + l < ne11) { + int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b; + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3; + } else { + buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f; + buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f; + } + } + + barrier(CLK_LOCAL_MEM_FENCE); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + + for (int i = 0; i < BK; i++) { + for (int j = 0; j < TM; j++) { + cache_a[j] = buf_a[(i) * BM + th_r * TM + j]; + } + + for (int j = 0; j < TN; j++) { + cache_b[j] = buf_b[(i) * BN + th_c * TN + j]; + } + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + const int sums_idx = cc*TM + cr; + sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]); + } + } + } + barrier(CLK_LOCAL_MEM_FENCE); + } + + const int dr = ir * BM + th_r * TM; + const int dc = ic * BN + th_c * TN; + + const int offsets = batch_idx * batch_stride_d; + + for (int cc = 0; cc < TN; cc++) { + for (int cr = 0; cr < TM; cr++) { + if (dr + cr < ne01 && dc + cc < ne11) { + dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr]; + } + } + } +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl new file mode 100644 index 0000000000..6fe828f20e --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32.cl @@ -0,0 +1,219 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y( + global const struct block_q4_1 * qb_curr, + float sumy, + float16 yl, + int il +) { + float d = qb_curr->d; + float m = qb_curr->m; + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + global const 
ushort * qs = ((global const ushort *) qb_curr + 2 + il/2); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32( + global void * src0, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined 
(ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32( + global void * src0, + ulong offset0, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = (global void*)((global char*)src0 + offset0); + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl new file mode 100644 index 0000000000..d7c4645d67 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_1_f32_flat.cl @@ -0,0 +1,229 @@ +#pragma OPENCL EXTENSION cl_khr_fp16 : enable + +#ifdef cl_intel_subgroups +#pragma OPENCL EXTENSION cl_intel_subgroups : enable +#else +#pragma OPENCL EXTENSION cl_khr_subgroups : enable +#endif + +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +#define QK4_1 32 + +struct block_q4_1 { + half d; // delta + half m; // min + uchar qs[QK4_1 / 2]; // nibbles / quants +}; + +inline float block_q4_1_dot_y_flat( + global const uchar * x, + global const half * dh, + global const half * mh, + float sumy, + float16 yl, + int il +) { + float d = *dh; + float m = *mh; + global const ushort * qs = ((global const ushort *) x + il/2); + + float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f); + + acc.s0 += yl.s0 * (qs[0] & 0x000F); + acc.s0 += yl.s1 * (qs[0] & 0x0F00); + acc.s0 += yl.s8 * (qs[0] & 0x00F0); + acc.s3 += yl.s9 * (qs[0] & 0xF000); + + acc.s0 += yl.s2 * (qs[1] & 0x000F); + acc.s1 += yl.s3 * (qs[1] & 0x0F00); + acc.s2 += yl.sa * (qs[1] & 0x00F0); + acc.s3 += yl.sb * (qs[1] & 0xF000); + + acc.s0 += yl.s4 * (qs[2] & 0x000F); + acc.s1 += yl.s5 * (qs[2] & 0x0F00); + acc.s2 += yl.sc * (qs[2] & 0x00F0); + acc.s3 += yl.sd * (qs[2] & 0xF000); + + acc.s0 += yl.s6 * (qs[3] & 0x000F); + acc.s1 += yl.s7 * (qs[3] & 0x0F00); + acc.s2 += yl.se * (qs[3] & 0x00F0); + acc.s3 += yl.sf * (qs[3] & 0xF000); + + return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m; +} + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // each subgroup works on 4 rows +#define N_SIMDGROUP 1 // number of subgroups in a thread group +#define N_SIMDWIDTH 16 // assuming subgroup size is 16 +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +inline void mul_vec_q_n_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + global float * dst, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + const ulong nb = ne00/QK4_1; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = 
im/ne12; + + ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02); + + // The number of scales/mins is the same as the number of blocks. + ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)); + // Each block contains QK4_1/2 uchars, hence offset for qs is as follows. + ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2; + + global uchar * x = (global uchar *) src0_q + offset0_q; + global half * d = (global half *) src0_d + offset0_dm; + global half * m = (global half *) src0_m + offset0_dm; + global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1; + + float16 yl; + float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f); + + int ix = get_sub_group_local_id()/2; + int il = 8*(get_sub_group_local_id()%2); + + global float * yb = y + ix * QK4_1 + il; + + for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) { + float sumy = 0; + + sumy += yb[0]; + sumy += yb[1]; + sumy += yb[2]; + sumy += yb[3]; + sumy += yb[4]; + sumy += yb[5]; + sumy += yb[6]; + sumy += yb[7]; + + sumy += yb[16]; + sumy += yb[17]; + sumy += yb[18]; + sumy += yb[19]; + sumy += yb[20]; + sumy += yb[21]; + sumy += yb[22]; + sumy += yb[23]; + + + yl.s0 = yb[0]; + yl.s1 = yb[1]/256.f; + + yl.s2 = yb[2]; + yl.s3 = yb[3]/256.f; + + yl.s4 = yb[4]; + yl.s5 = yb[5]/256.f; + + yl.s6 = yb[6]; + yl.s7 = yb[7]/256.f; + + yl.s8 = yb[16]/16.f; + yl.s9 = yb[17]/4096.f; + + yl.sa = yb[18]/16.f; + yl.sb = yb[19]/4096.f; + + yl.sc = yb[20]/16.f; + yl.sd = yb[21]/4096.f; + + yl.se = yb[22]/16.f; + yl.sf = yb[23]/4096.f; + + sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il); + sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il); + sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il); + sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il); + + yb += QK4_1 * (N_SIMDWIDTH/2); + } + + float4 tot = (float4)( + sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1), + sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3) + ); + + if (get_sub_group_local_id() == 0) { + if (first_row + 0 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0; + } + if (first_row + 1 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1; + } + if (first_row + 2 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2; + } + if (first_row + 3 < ne01) { + dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3; + } + } +} + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_1_f32_flat( + global void * src0_q, + global void * src0_d, + global void * src0_m, + global float * src1, + ulong offset1, + global float * dst, + ulong offsetd, + int ne00, + int ne01, + int ne02, + int ne10, + int ne12, + int ne0, + int ne1, + int r2, + int r3 +) { + src1 = (global float*)((global char*)src1 + offset1); + dst = (global float*)((global char*)dst + offsetd); + + mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3); +} diff --git a/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl new file mode 100644 index 0000000000..71ab989821 --- /dev/null +++ b/ggml/src/ggml-opencl/kernels/mul_mv_q4_k_f32.cl @@ -0,0 +1,180 @@ +#ifdef cl_intel_required_subgroup_size +#pragma OPENCL EXTENSION 
cl_intel_required_subgroup_size : enable +#define INTEL_GPU 1 +#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16))) +#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32))) +#elif defined(cl_qcom_reqd_sub_group_size) +#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable +#define ADRENO_GPU 1 +#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half"))) +#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full"))) +#endif + +//------------------------------------------------------------------------------ +// block_q4_K +//------------------------------------------------------------------------------ +#define QK_K 256 +#define K_SCALE_SIZE 12 + +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +typedef struct { + half d; // super-block scale for quantized scales + half dmin; // super-block scale for quantized mins + + uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uchar qs[QK_K/2]; // 4-bit quants +} block_q4_K; + +#undef N_DST +#undef N_SIMDGROUP +#undef N_SIMDWIDTH + +#ifdef INTEL_GPU +#define N_DST 4 // number of rows each SIMD group works on +#define N_SIMDGROUP 1 // number of SIMD groups in a thread group +#define N_SIMDWIDTH 16 // SIMD group size +#elif defined (ADRENO_GPU) +#define N_DST 4 +#define N_SIMDGROUP 1 +#define N_SIMDWIDTH 64 +#endif + +#undef BLOCK_STRIDE +// number of (super) blocks each subgroup processes +// each thread in a subgroup processes a block (32 weights) +#define BLOCK_STRIDE (N_SIMDWIDTH/8) + +#ifdef INTEL_GPU +REQD_SUBGROUP_SIZE_16 +#elif defined (ADRENO_GPU) +REQD_SUBGROUP_SIZE_64 +#endif +kernel void kernel_mul_mv_q4_K_f32( + global char * src0, + int offset0, + global char * src1, + int offset1, + global char * dst, + int offsetd, + int ne00, + int ne01, + ulong nb01, + ulong nb02, + ulong nb03, + int ne12, + ulong nb11, + ulong nb12, + ulong nb13, + int ne0, + int ne1, + int r2, + int r3 +) { + src0 = src0 + offset0; + src1 = src1 + offset1; + dst = dst + offsetd; + + ushort kmask1 = 0x3f3f; + ushort kmask2 = 0x0f0f; + ushort kmask3 = 0xc0c0; + + int ix = get_sub_group_local_id()/8; // super block index + int it = get_sub_group_local_id()%8; // block index (inside super block) + int iq = it/4; // 0 or 1 - first or second half of the super block + int ir = it%4; // 0...3 - block index in the half super block + + int nb = ne00/QK_K; + + int r0 = get_group_id(0); + int r1 = get_group_id(1); + int im = get_group_id(2); + int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST; + + int i12 = im%ne12; + int i13 = im/ne12; + + int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03; + int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13; + + global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0); + global float * y = (global float *) (src1 + offset_src1); + + float yl[16]; + float yh[16]; + float sumf[N_DST] = {0.f}; + float all_sum; + + global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir; + + ushort sc16[4]; + uchar * sc8 = (uchar *)sc16; + + for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) { + float4 sumy = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; ++i) { + yl[i+0] = y4[i+0]; + sumy.s0 += yl[i+0]; + + yl[i+8] = y4[i+32]; + sumy.s1 += yl[i+8]; + + yh[i+0] = y4[i+128]; + sumy.s2 += yh[i+0]; + + yh[i+8] = y4[i+160]; + sumy.s3 += yh[i+8]; + } + + global ushort * sc = (global ushort *)x[ib].scales + iq; + global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir; + global half * dh 
= &x[ib].d; + + for (int row = 0; row < N_DST; row++) { + sc16[0] = sc[0] & kmask1; + sc16[1] = sc[2] & kmask1; + sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2); + sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2); + + global ushort * q2 = q1 + 32; + + float4 acc1 = {0.f, 0.f, 0.f, 0.f}; + float4 acc2 = {0.f, 0.f, 0.f, 0.f}; + for (int i = 0; i < 8; i += 2) { + acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F); + acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00); + acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0); + acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000); + acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F); + acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00); + acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0); + acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000); + } + + float dall = dh[0]; + float dmin = dh[1]; + sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] + + (acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f + + (acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] + + (acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) - + dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]); + + q1 += nb01/2; + sc += nb01/2; + dh += nb01/2; + } + + y4 += BLOCK_STRIDE * QK_K; + } + + global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0; + + for (int row = 0; row < N_DST; ++row) { + all_sum = sub_group_reduce_add(sumf[row]); + if (first_row + row < ne01) { + if (get_sub_group_local_id() == 0) { + dst_f32[first_row + row] = all_sum; + } + } + } +} diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 72097ffd0f..114992da08 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -92,6 +92,7 @@ static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } #define VK_VENDOR_ID_APPLE 0x106b #define VK_VENDOR_ID_INTEL 0x8086 #define VK_VENDOR_ID_NVIDIA 0x10de +#define VK_VENDOR_ID_QUALCOMM 0x5143 #define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 @@ -687,6 +688,7 @@ struct vk_device_struct { vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_acc_f32; + vk_pipeline pipeline_set_f32; // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] vk_pipeline pipeline_add[2][2][2]; @@ -4080,7 +4082,7 @@ static void ggml_vk_load_shaders(vk_device& device) { } ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {1, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); @@ -4181,7 +4183,8 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); - ggml_vk_create_pipeline(device, device->pipeline_acc_f32, 
"acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 1}, 1); + ggml_vk_create_pipeline(device, device->pipeline_set_f32, "set_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {0, 0}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); @@ -5641,6 +5644,10 @@ static void ggml_vk_instance_init() { driver_priorities[vk::DriverId::eMesaNvk] = 2; #endif break; + case VK_VENDOR_ID_QUALCOMM: + driver_priorities[vk::DriverId::eQualcommProprietary] = 1; + driver_priorities[vk::DriverId::eMesaTurnip] = 2; + break; } driver_priorities[vk::DriverId::eMesaDozen] = 100; @@ -8422,6 +8429,8 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t acctype = f32acc ? 4 : 2; const uint32_t f16vec4 = 8; + const uint32_t tmpsh = (Bc / MatBc) * sizeof(float); + const uint32_t qstride = hsk_pad / 4 + 2; const uint32_t Qf = Br * qstride * f16vec4; @@ -8438,7 +8447,7 @@ static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, co const uint32_t slope = Br * acctype; - const uint32_t total_size = Qf + Psh + sfsh + ksh + slope; + const uint32_t total_size = tmpsh + Qf + Psh + sfsh + ksh + slope; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", kv_type=" << kv_type << ", total_size=" << total_size << ", supported=" << supported); @@ -8815,6 +8824,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return ctx->device->pipeline_acc_f32; } return nullptr; + case GGML_OP_SET: + if (src0->type == src1->type && src0->type == dst->type && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32)) { + return ctx->device->pipeline_set_f32; + } + return nullptr; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -9801,16 +9816,16 @@ static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const const uint32_t src1_type_size = ggml_type_size(src1->type); const uint32_t dst_type_size = ggml_type_size(dst->type); - int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 - int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 - // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused - int offset = dst->op_params[3] / 4; // offset in bytes + int nb1 = dst->op_params[0] / src0_type_size; // 4 bytes of float32 + int nb2 = dst->op_params[1] / src0_type_size; // 4 bytes of float32 + int nb3 = dst->op_params[2] / src0_type_size; // 4 bytes of float32 + int offset = dst->op_params[3] / src0_type_size; // offset in bytes - ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_ACC, { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, nullptr, dst, dst->op, { (uint32_t)ggml_nelements(src0), - (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / 
src0_type_size, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, - (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)nb3, 0, 0.0f, 0.0f, offset, }); @@ -10624,8 +10639,10 @@ static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& sub } static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - float * op_params = (float *)dst->op_params; - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f, 0.0f, 0.0f }); + const float * op_params = (const float *)dst->op_params; + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = op_params[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_L2_NORM, std::move(p)); } static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -12500,6 +12517,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr break; case GGML_OP_ACC: + case GGML_OP_SET: ggml_vk_acc(ctx, compute_ctx, src0, src1, node); break; @@ -14896,8 +14914,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return true; case GGML_OP_NORM: case GGML_OP_GROUP_NORM: - case GGML_OP_L2_NORM: return ggml_is_contiguous(op->src[0]); + case GGML_OP_L2_NORM: + return ggml_is_contiguous_rows(op->src[0]) && + op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; case GGML_OP_ADD: case GGML_OP_SUB: case GGML_OP_MUL: @@ -14960,7 +14980,10 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } return op->src[0]->type == GGML_TYPE_F32; case GGML_OP_ACC: - return op->src[0]->type == GGML_TYPE_F32; + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; + case GGML_OP_SET: + return op->src[0]->type == op->src[1]->type && op->src[0]->type == op->type && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_I32); case GGML_OP_CONCAT: return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32); case GGML_OP_ADD1: @@ -15611,6 +15634,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); } else if (tensor->op == GGML_OP_ACC) { tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_SET) { + tensor_clone = ggml_set(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); } else if (tensor->op == GGML_OP_NORM) { tensor_clone = ggml_norm(ggml_ctx, 
src_clone[0], *(float *)tensor->op_params); } else if (tensor->op == GGML_OP_GROUP_NORM) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp index 5084a70ed4..6ba3d1d89e 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -3,6 +3,9 @@ #include "types.glsl" #include "generic_binary_head.glsl" +// false for SET, true for ACC +layout(constant_id = 1) const bool ACC = true; + layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; void main() { @@ -13,17 +16,22 @@ void main() { const uint offset = p.param3; const uint src1_i = idx - offset; - const uint oz = src1_i / p.nb02; - const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; - const uint ox = src1_i % p.nb01; + const uint i3 = src1_i / p.nb03; + const uint rem2 = src1_i - i3 * p.nb03; + const uint i2 = rem2 / p.nb02; + const uint rem1 = rem2 - i2 * p.nb02; + const uint i1 = rem1 / p.nb01; + const uint i0 = rem1 % p.nb01; uint i00, i01, i02, i03; - get_indices(idx, i00, i01, i02, i03); - if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + if (i0 < p.ne10 && i1 < p.ne11 && i2 < p.ne12 && i3 < p.ne13) { + if (ACC) { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } else { + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_b[get_boffset() + src1_idx(i0, i1, i2, i3)])); + } } else { - data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx])); } } - diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp index 914f131c96..0735f67854 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -130,6 +130,7 @@ void main() { if (MASK_ENABLE && mask_opt_bits != MASK_OPT_ALL_ZERO) { bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + float max_mask = NEG_FLT_MAX_OVER_2; [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { uint32_t c = (idx + tid) % Bc; uint32_t r = (idx + tid) / Bc; @@ -137,12 +138,25 @@ void main() { if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { float m = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); masksh[c][r] = m; + max_mask = max(max_mask, m); } else { masksh[c][r] = float(0); } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? 
NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } float Sf[Br][cols_per_thread]; @@ -260,6 +274,9 @@ void main() { barrier(); } + // prevent race on tmpsh + barrier(); + // reduce across threads [[unroll]] for (uint32_t r = 0; r < Br; ++r) { diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp index b317773823..19630972da 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -42,6 +42,8 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY return elem; } +shared float tmpsh[row_split]; + const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 shared f16vec4 Qf[Br * qstride]; @@ -213,6 +215,19 @@ void main() { } } } + // skip the block if the mask is entirely -inf + bool all_less = subgroupAll(max_mask <= NEG_FLT_MAX_OVER_2); + barrier(); + if (gl_SubgroupInvocationID == 0) { + tmpsh[gl_SubgroupID] = all_less ? NEG_FLT_MAX_OVER_2 : 0.0f; + } + barrier(); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + max_mask = max(max_mask, tmpsh[s]); + } + if (max_mask <= NEG_FLT_MAX_OVER_2) { + continue; + } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp index 39f0c4d23b..853f17fa16 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -176,7 +176,14 @@ void main() { tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); tensorLayoutM = setTensorLayoutClampValueNV(tensorLayoutM, 0xfc00); // -inf in float16_t + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } else { tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); // Don't clamp against nem1 when GQA is enabled @@ -184,7 +191,14 @@ void main() { tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + coopmat mvmax; + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + // skip the block if the mask is entirely -inf + coopMatReduceNV(mvmax, mv, gl_CooperativeMatrixReduceRowAndColumnNV, maxReduceFp16); + if (mvmax[0] <= NEG_FLT_MAX_OVER_2) { + continue; + } } } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp index 83ef2f8795..7d0a1de0df 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -1,6 +1,6 @@ #version 450 -#include "generic_head.glsl" +#include "generic_unary_head.glsl" #include "types.glsl" #extension GL_EXT_control_flow_attributes : enable @@ -8,19 +8,22 @@ layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; -layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; -layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; - shared FLOAT_TYPE sum[BLOCK_SIZE]; void main() { const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + 
gl_WorkGroupID.x; const uint tid = gl_LocalInvocationID.x; + const uint i3 = row / (p.ne11 * p.ne12); + const uint i3_offset = i3 * p.ne12 * p.ne11; + const uint i2 = (row - i3_offset) / p.ne11; + const uint i2_offset = i2 * p.ne11; + const uint i1 = row - i3_offset - i2_offset; + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0]); sum[tid] += xi * xi; } @@ -35,7 +38,7 @@ void main() { const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); - [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { - data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + [[unroll]] for (uint i0 = tid; i0 < p.ne00; i0 += BLOCK_SIZE) { + data_d[i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0] = D_TYPE(scale * FLOAT_TYPE(data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0])); } } diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index f39dd8da3a..5e60d8b180 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -5751,7 +5751,7 @@ static struct ggml_tensor * ggml_unary_impl( struct ggml_tensor * a, enum ggml_unary_op op, bool inplace) { - GGML_ASSERT(ggml_is_contiguous_1(a)); + GGML_ASSERT(ggml_is_contiguous_rows(a)); struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a); diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index f685b2a000..d0761961f6 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -181,6 +181,11 @@ class Keys: SLIDING_WINDOW_PATTERN = "{arch}.attention.sliding_window_pattern" TEMPERATURE_SCALE = "{arch}.attention.temperature_scale" + class Indexer: + HEAD_COUNT = "{arch}.attention.indexer.head_count" + KEY_LENGTH = "{arch}.attention.indexer.key_length" + TOP_K = "{arch}.attention.indexer.top_k" + class Rope: DIMENSION_COUNT = "{arch}.rope.dimension_count" DIMENSION_SECTIONS = "{arch}.rope.dimension_sections" @@ -431,6 +436,7 @@ class MODEL_ARCH(IntEnum): CHATGLM = auto() GLM4 = auto() GLM4_MOE = auto() + GLM_DSA = auto() BITNET = auto() T5 = auto() T5ENCODER = auto() @@ -676,6 +682,10 @@ class MODEL_TENSOR(IntEnum): VISEXP_GATE = auto() VISEXP_DOWN = auto() VISEXP_UP = auto() + INDEXER_K_NORM = auto() + INDEXER_PROJ = auto() + INDEXER_ATTN_K = auto() + INDEXER_ATTN_Q_B = auto() # vision V_MMPROJ = auto() V_MMPROJ_FC = auto() @@ -881,6 +891,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.CHATGLM: "chatglm", MODEL_ARCH.GLM4: "glm4", MODEL_ARCH.GLM4_MOE: "glm4moe", + MODEL_ARCH.GLM_DSA: "glm-dsa", MODEL_ARCH.BITNET: "bitnet", MODEL_ARCH.T5: "t5", MODEL_ARCH.T5ENCODER: "t5encoder", @@ -1124,6 +1135,10 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = { MODEL_TENSOR.VISEXP_GATE: "blk.{bid}.vis_gate", MODEL_TENSOR.VISEXP_DOWN: "blk.{bid}.vis_down", MODEL_TENSOR.VISEXP_UP: "blk.{bid}.vis_up", + MODEL_TENSOR.INDEXER_K_NORM: "blk.{bid}.indexer.k_norm", + MODEL_TENSOR.INDEXER_PROJ: "blk.{bid}.indexer.proj", + MODEL_TENSOR.INDEXER_ATTN_K: "blk.{bid}.indexer.attn_k", + MODEL_TENSOR.INDEXER_ATTN_Q_B: "blk.{bid}.indexer.attn_q_b", # vision MODEL_TENSOR.V_MMPROJ: "mm.{bid}", MODEL_TENSOR.V_MMPROJ_FC: "mm.model.fc", @@ -2765,6 +2780,47 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, ], + MODEL_ARCH.GLM_DSA: [ + 
MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ROPE_FREQS, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_A, + MODEL_TENSOR.ATTN_Q_B, + MODEL_TENSOR.ATTN_KV_A_MQA, + MODEL_TENSOR.ATTN_KV_B, + MODEL_TENSOR.ATTN_K_B, + MODEL_TENSOR.ATTN_V_B, + MODEL_TENSOR.ATTN_Q_A_NORM, + MODEL_TENSOR.ATTN_KV_A_NORM, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.ATTN_ROT_EMBD, + MODEL_TENSOR.FFN_GATE_INP, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + MODEL_TENSOR.FFN_GATE_EXP, + MODEL_TENSOR.FFN_DOWN_EXP, + MODEL_TENSOR.FFN_UP_EXP, + MODEL_TENSOR.FFN_GATE_SHEXP, + MODEL_TENSOR.FFN_DOWN_SHEXP, + MODEL_TENSOR.FFN_UP_SHEXP, + MODEL_TENSOR.FFN_EXP_PROBS_B, + MODEL_TENSOR.INDEXER_K_NORM, + MODEL_TENSOR.INDEXER_PROJ, + MODEL_TENSOR.INDEXER_ATTN_K, + MODEL_TENSOR.INDEXER_ATTN_Q_B, + # NextN/MTP tensors - preserved but unused + MODEL_TENSOR.NEXTN_EH_PROJ, + MODEL_TENSOR.NEXTN_EMBED_TOKENS, + MODEL_TENSOR.NEXTN_ENORM, + MODEL_TENSOR.NEXTN_HNORM, + MODEL_TENSOR.NEXTN_SHARED_HEAD_HEAD, + MODEL_TENSOR.NEXTN_SHARED_HEAD_NORM, + ], MODEL_ARCH.BITNET: [ MODEL_TENSOR.ATTN_Q, MODEL_TENSOR.ATTN_K, @@ -3867,6 +3923,7 @@ class VisionProjectorType: MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" YOUTUVL = "youtuvl" + NEMOTRON_V2_VL = "nemotron_v2_vl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 7ca1322dc6..de610ad692 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -771,6 +771,15 @@ class GGUFWriter: def add_value_length_mla(self, length: int) -> None: self.add_uint32(Keys.Attention.VALUE_LENGTH_MLA.format(arch=self.arch), length) + def add_indexer_head_count(self, count: int) -> None: + self.add_uint32(Keys.Attention.Indexer.HEAD_COUNT.format(arch=self.arch), count) + + def add_indexer_key_length(self, length: int) -> None: + self.add_uint32(Keys.Attention.Indexer.KEY_LENGTH.format(arch=self.arch), length) + + def add_indexer_top_k(self, top_k: int) -> None: + self.add_uint32(Keys.Attention.Indexer.TOP_K.format(arch=self.arch), top_k) + def add_max_alibi_bias(self, bias: float) -> None: self.add_float32(Keys.Attention.MAX_ALIBI_BIAS.format(arch=self.arch), bias) diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 40c15be10b..579bf17ccb 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1206,6 +1206,22 @@ class TensorNameMap: "model.layers.{bid}.self_attn.vision_expert_query_key_value", # cogvlm ), + MODEL_TENSOR.INDEXER_K_NORM: ( + "model.layers.{bid}.self_attn.indexer.k_norm", # DSA + ), + + MODEL_TENSOR.INDEXER_PROJ: ( + "model.layers.{bid}.self_attn.indexer.weights_proj", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_K: ( + "model.layers.{bid}.self_attn.indexer.wk", # DSA + ), + + MODEL_TENSOR.INDEXER_ATTN_Q_B: ( + "model.layers.{bid}.self_attn.indexer.wq_b", # DSA + ), + ############################################################################ # TODO: these do not belong to block_mappings_cfg - move them to mappings_cfg MODEL_TENSOR.ENC_OUTPUT_NORM: ( @@ -1331,6 +1347,7 @@ class TensorNameMap: "model.vision_tower.embeddings.cls_token", # Intern-S1 "vision_model.class_embedding", # llama 4 "model.vision.patch_embedding.cls_embedding", # cogvlm + "vision_model.radio_model.model.patch_generator.cls_token.token", # Nemotron Nano v2 VL "model.vision_model.embeddings.class_embedding", # Deepseek-OCR ), @@ -1347,6 +1364,7 @@ class TensorNameMap: 
"model.vision.patch_embedding.proj", # cogvlm "model.vision_model.embeddings.patch_embedding", # Deepseek-OCR CLIP "siglip2.vision_model.embeddings.patch_embedding", + "vision_model.radio_model.model.patch_generator.embedder", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1363,6 +1381,7 @@ class TensorNameMap: "visual.pos_embed", # qwen3vl "model.vision.patch_embedding.position_embedding", # cogvlm "visual.embeddings.position_embedding", # glm4v + "vision_model.radio_model.model.patch_generator.pos_embed", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_EMBD_IMGNL: ( @@ -1378,6 +1397,7 @@ class TensorNameMap: "model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm "model.vision_model.transformer.layers.{bid}.self_attn.qkv_proj", # Deepseek-OCR CLIP "vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5 + "vision_model.radio_model.model.blocks.{bid}.attn.qkv", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_ATTN_Q: ( @@ -1446,6 +1466,7 @@ class TensorNameMap: "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm "model.vision_model.transformer.layers.{bid}.layer_norm1", # Deepseek-OCR CLIP "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", + "vision_model.radio_model.model.blocks.{bid}.norm1", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1463,6 +1484,7 @@ class TensorNameMap: "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm "model.vision_model.transformer.layers.{bid}.self_attn.out_proj", # Deepseek-OCR CLIP "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl + "vision_model.radio_model.model.blocks.{bid}.attn.proj", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1479,6 +1501,7 @@ class TensorNameMap: "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm "model.vision_model.transformer.layers.{bid}.layer_norm2", # Deepseek-OCR CLIP "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", + "vision_model.radio_model.model.blocks.{bid}.norm2", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1496,6 +1519,7 @@ class TensorNameMap: "model.vision_model.transformer.layers.{bid}.mlp.fc1", # Deepseek-OCR CLIP "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", + "vision_model.radio_model.model.blocks.{bid}.mlp.fc1", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1519,6 +1543,7 @@ class TensorNameMap: "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm "model.vision_model.transformer.layers.{bid}.mlp.fc2", # Deepseek-OCR CLIP "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", + "vision_model.radio_model.model.blocks.{bid}.mlp.fc2", # Nemotron Nano v2 VL ), MODEL_TENSOR.V_LAYER_SCALE_1: ( diff --git a/include/llama.h b/include/llama.h index 46c3672e98..d2d7f59ebc 100644 --- a/include/llama.h +++ b/include/llama.h @@ -656,21 +656,12 @@ extern "C" { // The following functions operate on a llama_context, hence the naming: llama_verb_... - // Add a loaded LoRA adapter to given context - // This will not modify model's weight - LLAMA_API int32_t llama_set_adapter_lora( + // Set LoRa adapters on the context. Will only modify if the adapters currently in context are different. 
+ LLAMA_API int32_t llama_set_adapters_lora( struct llama_context * ctx, - struct llama_adapter_lora * adapter, - float scale); - - // Remove a specific LoRA adapter from given context - // Return -1 if the adapter is not present in the context - LLAMA_API int32_t llama_rm_adapter_lora( - struct llama_context * ctx, - struct llama_adapter_lora * adapter); - - // Remove all LoRA adapters from given context - LLAMA_API void llama_clear_adapter_lora(struct llama_context * ctx); + struct llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales); // Apply a loaded control vector to a llama_context, or if data is NULL, clear // the currently loaded vector. @@ -678,7 +669,7 @@ extern "C" { // to an n_embd x n_layers buffer starting from layer 1. // il_start and il_end are the layer range the vector should apply to (both inclusive) // See llama_control_vector_load in common to load a control vector. - LLAMA_API int32_t llama_apply_adapter_cvec( + LLAMA_API int32_t llama_set_adapter_cvec( struct llama_context * ctx, const float * data, size_t len, @@ -1150,9 +1141,9 @@ extern "C" { // /// Apply chat template. Inspired by hf apply_chat_template() on python. - /// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model" + /// /// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template - /// @param tmpl A Jinja template to use for this chat. If this is nullptr, the model’s default chat template will be used instead. + /// @param tmpl A Jinja template to use for this chat. /// @param chat Pointer to a list of multiple llama_chat_message /// @param n_msg Number of llama_chat_message in this chat /// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message. diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh index bd635f3b9d..067f5d466b 100755 --- a/scripts/pr2wt.sh +++ b/scripts/pr2wt.sh @@ -30,12 +30,18 @@ fi PR=$1 [[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } +url_origin=$(git config --get remote.upstream.url 2>/dev/null) || \ url_origin=$(git config --get remote.origin.url) || { - echo "error: no remote named 'origin' in this repository" + echo "error: no remote named 'upstream' or 'origin' in this repository" exit 1 } -org_repo=$(echo $url_origin | cut -d/ -f4-) +# Extract org/repo from either https or ssh format. 
+if [[ $url_origin =~ ^git@ ]]; then + org_repo=$(echo $url_origin | cut -d: -f2) +else + org_repo=$(echo $url_origin | cut -d/ -f4-) +fi org_repo=${org_repo%.git} echo "org/repo: $org_repo" diff --git a/scripts/sync-ggml.last b/scripts/sync-ggml.last index 81e79a9470..02a096882e 100644 --- a/scripts/sync-ggml.last +++ b/scripts/sync-ggml.last @@ -1 +1 @@ -a8db410a252c8c8f2d120c6f2e7133ebe032f35d +d6754f3d0e6d0acd21c12442353c9fd2f94188e7 diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 1ff6a9a40f..fe1286d009 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -1,6 +1,11 @@ #!/usr/bin/env python3 import urllib.request +import os +import sys +import subprocess + +HTTPLIB_VERSION = "d4180e923f846b44a3d30acd938438d6e64fc9f6" vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", @@ -12,8 +17,9 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "httplib.h", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/split.py": "split.py", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } @@ -22,19 +28,16 @@ for url, filename in vendor.items(): print(f"downloading {url} to {filename}") # noqa: NP100 urllib.request.urlretrieve(url, filename) - # split cpp/h files for httplib - # see: https://github.com/yhirose/cpp-httplib/blob/master/split.py - if 'httplib.h' in filename: - border = '// ----------------------------------------------------------------------------' - with open(filename, 'r') as f: - content = f.read() - header, implementation, footer = content.split(border, 2) - fname_cpp = filename.replace('.h', '.cpp') - with open(filename, 'w') as fh: - fh.write(header) - fh.write(footer) - with open(fname_cpp, 'w') as fc: - fc.write('#include "httplib.h"\n') - fc.write('namespace httplib {\n') - fc.write(implementation.replace('\ninline ', '\n')) - fc.write('} // namespace httplib\n') +print("Splitting httplib.h...") # noqa: NP100 +try: + subprocess.check_call([ + sys.executable, "split.py", + "--extension", "cpp", + "--out", "vendor/cpp-httplib" + ]) +except Exception as e: + print(f"Error: {e}") # noqa: NP100 + sys.exit(1) +finally: + os.remove("split.py") + os.remove("httplib.h") diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index fdda05d3ea..daf249422a 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -57,13 +57,14 @@ add_library(llama models/deci.cpp models/deepseek.cpp models/deepseek2.cpp + models/delta-net-base.cpp models/dots1.cpp models/dream.cpp models/ernie4-5-moe.cpp models/ernie4-5.cpp + models/exaone-moe.cpp models/exaone.cpp models/exaone4.cpp - models/exaone-moe.cpp models/falcon-h1.cpp models/falcon.cpp models/gemma-embedding.cpp @@ -91,10 +92,12 @@ add_library(llama models/llama-iswa.cpp models/llama.cpp models/maincoder.cpp + 
models/mamba-base.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp models/minimax-m2.cpp + models/mistral3.cpp models/modern-bert.cpp models/mpt.cpp models/nemotron-h.cpp @@ -118,12 +121,12 @@ add_library(llama models/qwen2moe.cpp models/qwen2vl.cpp models/qwen3.cpp - models/qwen3vl.cpp - models/qwen3vl-moe.cpp - models/qwen3moe.cpp - models/qwen3next.cpp models/qwen35.cpp models/qwen35moe.cpp + models/qwen3moe.cpp + models/qwen3next.cpp + models/qwen3vl-moe.cpp + models/qwen3vl.cpp models/refact.cpp models/rnd1.cpp models/rwkv6-base.cpp @@ -142,8 +145,6 @@ add_library(llama models/t5-enc.cpp models/wavtokenizer-dec.cpp models/xverse.cpp - models/mistral3.cpp - models/graph-context-mamba.cpp ) set_target_properties(llama PROPERTIES diff --git a/src/llama-adapter.h b/src/llama-adapter.h index d275d25425..aa3ab63ad7 100644 --- a/src/llama-adapter.h +++ b/src/llama-adapter.h @@ -39,6 +39,8 @@ private: std::vector tensors; // per layer }; +using llama_adapter_cvec_ptr = std::shared_ptr; + // // llama_adapter_lora // @@ -84,3 +86,4 @@ struct llama_adapter_lora { }; using llama_adapter_loras = std::unordered_map; +using llama_adapter_loras_ptr = std::unique_ptr; diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index fa7f5e20a3..96edd0b116 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -75,6 +75,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_CHATGLM, "chatglm" }, { LLM_ARCH_GLM4, "glm4" }, { LLM_ARCH_GLM4_MOE, "glm4moe" }, + { LLM_ARCH_GLM_DSA, "glm-dsa" }, { LLM_ARCH_BITNET, "bitnet" }, { LLM_ARCH_T5, "t5" }, { LLM_ARCH_T5ENCODER, "t5encoder" }, @@ -226,6 +227,9 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_TEMPERATURE_SCALE, "%s.attention.temperature_scale" }, { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, + { LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, "%s.attention.indexer.head_count" }, + { LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, "%s.attention.indexer.key_length" }, + { LLM_KV_ATTENTION_INDEXER_TOP_K, "%s.attention.indexer.top_k" }, { LLM_KV_ROPE_DIMENSION_COUNT, "%s.rope.dimension_count" }, { LLM_KV_ROPE_DIMENSION_SECTIONS, "%s.rope.dimension_sections" }, @@ -517,6 +521,10 @@ static const std::map LLM_TENSOR_NAMES = { { LLM_TENSOR_VISEXP_FFN_GATE, "blk.%d.vis_gate" }, { LLM_TENSOR_VISEXP_FFN_DOWN, "blk.%d.vis_down" }, { LLM_TENSOR_VISEXP_FFN_UP, "blk.%d.vis_up" }, + { LLM_TENSOR_INDEXER_K_NORM, "blk.%d.indexer.k_norm" }, + { LLM_TENSOR_INDEXER_PROJ, "blk.%d.indexer.proj" }, + { LLM_TENSOR_INDEXER_ATTN_K, "blk.%d.indexer.attn_k" }, + { LLM_TENSOR_INDEXER_ATTN_Q_B, "blk.%d.indexer.attn_q_b" }, }; static std::set llm_get_tensor_names(llm_arch arch) { @@ -1690,6 +1698,46 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, }; + case LLM_ARCH_GLM_DSA: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q_A_NORM, + LLM_TENSOR_ATTN_KV_A_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_A, + LLM_TENSOR_ATTN_Q_B, + LLM_TENSOR_ATTN_KV_A_MQA, + LLM_TENSOR_ATTN_KV_B, + LLM_TENSOR_ATTN_K_B, + LLM_TENSOR_ATTN_V_B, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_UP, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_GATE_INP, + LLM_TENSOR_FFN_GATE_EXPS, + LLM_TENSOR_FFN_DOWN_EXPS, + LLM_TENSOR_FFN_UP_EXPS, + LLM_TENSOR_FFN_GATE_INP_SHEXP, + LLM_TENSOR_FFN_GATE_SHEXP, + LLM_TENSOR_FFN_DOWN_SHEXP, + 
LLM_TENSOR_FFN_UP_SHEXP, + LLM_TENSOR_FFN_EXP_PROBS_B, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, + LLM_TENSOR_NEXTN_EH_PROJ, + LLM_TENSOR_NEXTN_EMBED_TOKENS, + LLM_TENSOR_NEXTN_ENORM, + LLM_TENSOR_NEXTN_HNORM, + LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, + LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, + }; case LLM_ARCH_BITNET: return { LLM_TENSOR_TOKEN_EMBD, @@ -2676,6 +2724,10 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_VISEXP_FFN_GATE, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_DOWN, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, {LLM_TENSOR_VISEXP_FFN_UP, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_K_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, + {LLM_TENSOR_INDEXER_PROJ, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_K, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_INDEXER_ATTN_Q_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}}, // NextN/MTP tensors are currently ignored (reserved for future MTP support) // These tensors only exist in the last layer(s) and are treated as output tensors {LLM_TENSOR_NEXTN_EH_PROJ, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, diff --git a/src/llama-arch.h b/src/llama-arch.h index a1aaf77069..7f96bf6fff 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -79,6 +79,7 @@ enum llm_arch { LLM_ARCH_CHATGLM, LLM_ARCH_GLM4, LLM_ARCH_GLM4_MOE, + LLM_ARCH_GLM_DSA, LLM_ARCH_BITNET, LLM_ARCH_T5, LLM_ARCH_T5ENCODER, @@ -230,6 +231,9 @@ enum llm_kv { LLM_KV_ATTENTION_TEMPERATURE_SCALE, LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, + LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, + LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, + LLM_KV_ATTENTION_INDEXER_TOP_K, LLM_KV_ROPE_DIMENSION_COUNT, LLM_KV_ROPE_DIMENSION_SECTIONS, @@ -518,6 +522,10 @@ enum llm_tensor { LLM_TENSOR_VISEXP_FFN_GATE, LLM_TENSOR_VISEXP_FFN_DOWN, LLM_TENSOR_VISEXP_FFN_UP, + LLM_TENSOR_INDEXER_K_NORM, + LLM_TENSOR_INDEXER_PROJ, + LLM_TENSOR_INDEXER_ATTN_K, + LLM_TENSOR_INDEXER_ATTN_Q_B, LLM_TENSOR_NEXTN_EH_PROJ, LLM_TENSOR_NEXTN_EMBED_TOKENS, LLM_TENSOR_NEXTN_ENORM, diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 6b43ca1926..fc05989aa5 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -22,6 +22,8 @@ llama_context::llama_context( const llama_model & model, llama_context_params params) : model(model), + cvec(std::make_unique()), + loras(std::make_unique()), balloc(std::make_unique(model.hparams.n_pos_per_embd())) { // TODO warning when creating llama_context with awkward ctx size that is not a power of 2, // may need to be backend-dependent @@ -878,6 +880,7 @@ const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { } } catch (const std::exception & err) { // fallback to full vocab list + GGML_UNUSED(err); } return sampling.token_ids_full_vocab.data(); @@ -1057,51 +1060,43 @@ bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { return true; } -void llama_context::set_adapter_lora( - llama_adapter_lora * adapter, - float scale) { - LLAMA_LOG_DEBUG("%s: adapter = %p, scale = %f\n", __func__, (void *) adapter, scale); +void llama_context::set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); - if (auto it = loras.find(adapter); it != loras.end()) { - if (it->second == scale) { - return; - } - } - - loras[adapter] = scale; - - sched_need_reserve = 
true; -} - -bool llama_context::rm_adapter_lora( - llama_adapter_lora * adapter) { - LLAMA_LOG_DEBUG("%s: adapter = %p\n", __func__, (void *) adapter); - - auto it = loras.find(adapter); - if (it != loras.end()) { - loras.erase(it); - - sched_need_reserve = true; - - return true; - } - - return false; -} - -void llama_context::clear_adapter_lora() { - LLAMA_LOG_DEBUG("%s: call\n", __func__); - - if (loras.empty()) { + if (adapters_lora_are_same(adapters, n_adapters, scales)) { return; } - loras.clear(); + loras.reset(new llama_adapter_loras()); + + for (size_t i = 0; i < n_adapters; i ++) { + if (scales[i] != 0.0f) { + loras->insert({adapters[i], scales[i]}); + } + } sched_need_reserve = true; } -bool llama_context::apply_adapter_cvec( +bool llama_context::adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales) { + LLAMA_LOG_DEBUG("%s: adapters = %p\n", __func__, (void *) adapters); + + if (n_adapters != loras->size()) { + return false; + } + + for (size_t i = 0; i < n_adapters; i ++) { + auto it = loras->find(adapters[i]); + + if (it == loras->end() || it->second != scales[i]) { + return false; + } + } + + return true; +} + +bool llama_context::set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -1111,7 +1106,7 @@ bool llama_context::apply_adapter_cvec( // TODO: should we reserve? - return cvec.apply(model, data, len, n_embd, il_start, il_end); + return cvec->apply(model, data, len, n_embd, il_start, il_end); } llm_graph_result * llama_context::process_ubatch(const llama_ubatch & ubatch, llm_graph_type gtype, llama_memory_context_i * mctx, ggml_status & ret) { @@ -1817,7 +1812,6 @@ int llama_context::decode(const llama_batch & batch_inp) { // uint32_t llama_context::output_reserve(int32_t n_outputs) { - const auto & hparams = model.hparams; const auto & vocab = model.vocab; @@ -1901,11 +1895,6 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { embd = has_embd ? 
buffer_view{(float *) (base + offset), embd.size} : buffer_view{nullptr, 0}; offset += embd.size * sizeof(float); - sampling.logits = {nullptr, 0}; - sampling.probs = {nullptr, 0}; - sampling.sampled = {nullptr, 0}; - sampling.candidates = {nullptr, 0}; - if (has_sampling) { sampling.logits = {(float *) (base + offset), (size_t)(n_vocab*n_outputs_max)}; offset += sampling.logits.size * sizeof(float); @@ -1931,6 +1920,15 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); std::fill_n(sampling.sampled.data, sampling.sampled.size, LLAMA_TOKEN_NULL); + } else { + sampling.logits = {nullptr, 0}; + sampling.probs = {nullptr, 0}; + sampling.sampled = {nullptr, 0}; + sampling.candidates = {nullptr, 0}; + + sampling.logits_count.clear(); + sampling.probs_count.clear(); + sampling.candidates_count.clear(); } // set all ids as invalid (negative) @@ -1961,37 +1959,30 @@ void llama_context::output_reorder() { } } - if (sampling.logits.has_data()) { + if (!sampling.samplers.empty()) { + assert(sampling.logits.size > 0); + assert(sampling.probs.size > 0); + assert(sampling.candidates.size > 0); + assert(sampling.sampled.size > 0); + assert(sampling.logits_count.size() > 0); + assert(sampling.probs_count.size() > 0); + assert(sampling.candidates_count.size() > 0); + for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.logits.data[i0*n_vocab + k], sampling.logits.data[i1*n_vocab + k]); } - } - if (sampling.probs.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.probs.data[i0*n_vocab + k], sampling.probs.data[i1*n_vocab + k]); } - } - if (sampling.candidates.has_data()) { for (uint64_t k = 0; k < n_vocab; ++k) { std::swap(sampling.candidates.data[i0*n_vocab + k], sampling.candidates.data[i1*n_vocab + k]); } - } - if (sampling.sampled.has_data()) { - std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); - } - - if (!sampling.logits_count.empty()) { - std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); - } - - if (!sampling.probs_count.empty()) { - std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); - } - - if (!sampling.candidates_count.empty()) { + std::swap(sampling.sampled.data[i0], sampling.sampled.data[i1]); + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); } } @@ -2092,8 +2083,8 @@ llm_graph_params llama_context::graph_params( /*.gtype =*/ gtype, /*.sched =*/ sched.get(), /*.backend_cpu =*/ backend_cpu, - /*.cvec =*/ &cvec, - /*.loras =*/ &loras, + /*.cvec =*/ cvec.get(), + /*.loras =*/ loras.get(), /*.mctx =*/ mctx, /*.cross =*/ &cross, /*.samplers =*/ sampling.samplers, @@ -3209,35 +3200,28 @@ uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { // llama adapter API -int32_t llama_set_adapter_lora( +int32_t llama_set_adapters_lora( llama_context * ctx, - llama_adapter_lora * adapter, - float scale) { - ctx->set_adapter_lora(adapter, scale); + llama_adapter_lora ** adapters, + size_t n_adapters, + float * scales) { + if (adapters == nullptr || scales == nullptr) { + GGML_ASSERT(n_adapters == 0 && "invalid llama_set_adapters_lora call"); + } + + ctx->set_adapters_lora(adapters, n_adapters, scales); return 0; } -int32_t llama_rm_adapter_lora( - llama_context * ctx, - llama_adapter_lora * adapter) { - bool res = ctx->rm_adapter_lora(adapter); - - return res ? 
0 : -1; -} - -void llama_clear_adapter_lora(llama_context * ctx) { - ctx->clear_adapter_lora(); -} - -int32_t llama_apply_adapter_cvec( +int32_t llama_set_adapter_cvec( llama_context * ctx, - const float * data, - size_t len, - int32_t n_embd, - int32_t il_start, - int32_t il_end) { - bool res = ctx->apply_adapter_cvec(data, len, n_embd, il_start, il_end); + const float * data, + size_t len, + int32_t n_embd, + int32_t il_start, + int32_t il_end) { + bool res = ctx->set_adapter_cvec(data, len, n_embd, il_start, il_end); return res ? 0 : -1; } diff --git a/src/llama-context.h b/src/llama-context.h index d995117574..e0d0085c1c 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -105,16 +105,11 @@ struct llama_context { void set_causal_attn(bool value); void set_warmup(bool value); - void set_adapter_lora( - llama_adapter_lora * adapter, - float scale); + void set_adapters_lora(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - bool rm_adapter_lora( - llama_adapter_lora * adapter); + bool adapters_lora_are_same(llama_adapter_lora ** adapters, size_t n_adapters, float * scales); - void clear_adapter_lora(); - - bool apply_adapter_cvec( + bool set_adapter_cvec( const float * data, size_t len, int32_t n_embd, @@ -261,33 +256,36 @@ private: const llama_model & model; - llama_cparams cparams; - llama_adapter_cvec cvec; - llama_adapter_loras loras; + llama_cparams cparams; + + llama_adapter_cvec_ptr cvec; + llama_adapter_loras_ptr loras; llama_cross cross; // TODO: tmp for handling cross-attention - need something better probably std::unique_ptr memory; // decode output (2-dimensional array: [n_outputs][n_vocab]) - struct buffer_view logits = {nullptr, 0}; + buffer_view logits = {nullptr, 0}; // embeddings output (2-dimensional array: [n_outputs][n_embd]) // populated only when pooling_type == LLAMA_POOLING_TYPE_NONE - struct buffer_view embd = {nullptr, 0}; + buffer_view embd = {nullptr, 0}; struct sampling_info { + // !samplers.empty() to check if any samplers are active std::map samplers; - struct buffer_view logits = {nullptr, 0}; - struct buffer_view sampled = {nullptr, 0}; - struct buffer_view probs = {nullptr, 0}; - struct buffer_view candidates = {nullptr, 0}; + buffer_view logits = {nullptr, 0}; + buffer_view sampled = {nullptr, 0}; + buffer_view probs = {nullptr, 0}; + buffer_view candidates = {nullptr, 0}; std::vector logits_count; std::vector probs_count; std::vector candidates_count; + // optimization std::vector token_ids_full_vocab; }; diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index ff59eb0a92..fe0a57f778 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -17,6 +17,41 @@ #include #include +// dedup helpers + +static ggml_tensor * build_kq_mask( + ggml_context * ctx, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; + + return ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); +} + +static bool can_reuse_kq_mask( + ggml_tensor * kq_mask, + const llama_kv_cache_context * mctx, + const llama_ubatch & ubatch, + const llama_cparams & cparams) { + const auto n_kv = mctx->get_n_kv(); + const auto n_tokens = ubatch.n_tokens; + const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; + + bool res = true; + + res &= (kq_mask->ne[0] == n_kv); + res &= (kq_mask->ne[1] == n_tokens/n_stream); + res &= (kq_mask->ne[2] == 1); + res &= (kq_mask->ne[3] == n_stream); + + return res; +} + +// impl + void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { const int64_t n_tokens = ubatch->n_tokens; @@ -403,8 +438,7 @@ bool llm_graph_input_attn_kv::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -424,8 +458,7 @@ bool llm_graph_input_attn_k::can_reuse(const llm_graph_params & params) { res &= self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= self_kq_mask->ne[0] == mctx->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx, params.ubatch, params.cparams); return res; } @@ -455,11 +488,8 @@ bool llm_graph_input_attn_kv_iswa::can_reuse(const llm_graph_params & params) { res &= self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= self_kq_mask->ne[0] == mctx->get_base()->get_n_kv(); - res &= self_kq_mask->ne[1] == params.ubatch.n_tokens; - - res &= self_kq_mask_swa->ne[0] == mctx->get_swa()->get_n_kv(); - res &= self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(self_kq_mask, mctx->get_base(), params.ubatch, params.cparams); + res &= can_reuse_kq_mask(self_kq_mask_swa, mctx->get_swa(), params.ubatch, params.cparams); return res; } @@ -521,8 +551,7 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -565,8 +594,7 @@ bool llm_graph_input_mem_hybrid_k::can_reuse(const llm_graph_params & params) { res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; - res &= inp_attn->self_kq_mask->ne[0] == mctx->get_attn()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, mctx->get_attn(), params.ubatch, params.cparams); res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -625,8 +653,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask->ne[0] == attn_ctx->get_base()->get_n_kv(); - res &= inp_attn->self_kq_mask->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask, attn_ctx->get_base(), params.ubatch, params.cparams); } // swa tensors may not be allocated if there are no SWA 
attention layers @@ -634,8 +661,7 @@ bool llm_graph_input_mem_hybrid_iswa::can_reuse(const llm_graph_params & params) res &= inp_attn->self_k_idxs_swa->ne[0] == params.ubatch.n_tokens; //res &= inp_attn->self_v_idxs_swa->ne[0] == params.ubatch.n_tokens; // TODO: need to move this to the unified cache and check there - res &= inp_attn->self_kq_mask_swa->ne[0] == attn_ctx->get_swa()->get_n_kv(); - res &= inp_attn->self_kq_mask_swa->ne[1] == params.ubatch.n_tokens; + res &= can_reuse_kq_mask(inp_attn->self_kq_mask_swa, attn_ctx->get_swa(), params.ubatch, params.cparams); } res &= inp_rs->s_copy->ne[0] == mctx->get_recr()->get_n_rs(); @@ -1891,14 +1917,11 @@ static std::unique_ptr build_attn_inp_kv_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); + ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -1983,13 +2006,9 @@ static std::unique_ptr build_attn_inp_k_impl( { GGML_ASSERT(hparams.swa_type == LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache_iswa for SWA"); - const auto n_kv = mctx_cur->get_n_kv(); - const auto n_tokens = ubatch.n_tokens; - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - inp->self_k_idxs = mctx_cur->build_input_k_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur, ubatch, cparams); ggml_set_input(inp->self_kq_mask); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; @@ -2188,15 +2207,11 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const auto inp = std::make_unique(hparams, cparams, mctx_cur); - const auto n_stream = cparams.kv_unified ? 
1 : ubatch.n_seqs_unq; - { - const auto n_kv = mctx_cur->get_base()->get_n_kv(); - inp->self_k_idxs = mctx_cur->get_base()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs = mctx_cur->get_base()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask = build_kq_mask(ctx0, mctx_cur->get_base(), ubatch, cparams); ggml_set_input(inp->self_kq_mask); ggml_set_name(inp->self_kq_mask, "self_kq_mask"); @@ -2207,12 +2222,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const { GGML_ASSERT(hparams.swa_type != LLAMA_SWA_TYPE_NONE && "Use llama_kv_cache for non-SWA"); - const auto n_kv = mctx_cur->get_swa()->get_n_kv(); - inp->self_k_idxs_swa = mctx_cur->get_swa()->build_input_k_idxs(ctx0, ubatch); inp->self_v_idxs_swa = mctx_cur->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp->self_kq_mask_swa = build_kq_mask(ctx0, mctx_cur->get_swa(), ubatch, cparams); ggml_set_input(inp->self_kq_mask_swa); ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); @@ -2374,27 +2387,21 @@ llm_graph_input_mem_hybrid_iswa * llm_graph_context::build_inp_mem_hybrid_iswa() auto inp_attn = std::make_unique(hparams, cparams, attn_ctx); - const auto n_stream = cparams.kv_unified ? 1 : ubatch.n_seqs_unq; - { - const auto n_kv = attn_ctx->get_base()->get_n_kv(); - inp_attn->self_k_idxs = attn_ctx->get_base()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs = attn_ctx->get_base()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask = build_kq_mask(ctx0, attn_ctx->get_base(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask); inp_attn->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp_attn->self_kq_mask, GGML_TYPE_F16) : inp_attn->self_kq_mask; } { - const auto n_kv = attn_ctx->get_swa()->get_n_kv(); - inp_attn->self_k_idxs_swa = attn_ctx->get_swa()->build_input_k_idxs(ctx0, ubatch); inp_attn->self_v_idxs_swa = attn_ctx->get_swa()->build_input_v_idxs(ctx0, ubatch); - inp_attn->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); + inp_attn->self_kq_mask_swa = build_kq_mask(ctx0, attn_ctx->get_swa(), ubatch, cparams); ggml_set_input(inp_attn->self_kq_mask_swa); inp_attn->self_kq_mask_swa_cnv = cparams.flash_attn ? 
ggml_cast(ctx0, inp_attn->self_kq_mask_swa, GGML_TYPE_F16) : inp_attn->self_kq_mask_swa; diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 706eda8441..c4b2a99da5 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -193,6 +193,11 @@ struct llama_hparams { std::array xielu_beta; std::array xielu_eps; + // DSA (deepseek sparse attention) + uint32_t indexer_n_head = 0; + uint32_t indexer_head_size = 0; + uint32_t indexer_top_k = 0; + // qwen3vl deepstack uint32_t n_deepstack_layers = 0; diff --git a/src/llama-mmap.cpp b/src/llama-mmap.cpp index 0261e4c72c..c03228e9ce 100644 --- a/src/llama-mmap.cpp +++ b/src/llama-mmap.cpp @@ -504,6 +504,8 @@ struct llama_mmap::impl { } } #elif defined(_WIN32) + HANDLE hMapping = nullptr; + impl(struct llama_file * file, size_t prefetch, bool numa) { GGML_UNUSED(numa); @@ -511,7 +513,7 @@ struct llama_mmap::impl { HANDLE hFile = (HANDLE) _get_osfhandle(file->file_id()); - HANDLE hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); + hMapping = CreateFileMappingA(hFile, NULL, PAGE_READONLY, 0, 0, NULL); if (hMapping == NULL) { DWORD error = GetLastError(); @@ -520,9 +522,9 @@ struct llama_mmap::impl { addr = MapViewOfFile(hMapping, FILE_MAP_READ, 0, 0, 0); DWORD error = GetLastError(); - CloseHandle(hMapping); if (addr == NULL) { + CloseHandle(hMapping); throw std::runtime_error(format("MapViewOfFile failed: %s", llama_format_win_err(error).c_str())); } @@ -554,9 +556,17 @@ struct llama_mmap::impl { } ~impl() { - if (!UnmapViewOfFile(addr)) { - LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", - llama_format_win_err(GetLastError()).c_str()); + if (hMapping) { + if (addr) { + if (!UnmapViewOfFile(addr)) { + LLAMA_LOG_WARN("warning: UnmapViewOfFile failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } + } + if (!CloseHandle(hMapping)) { + LLAMA_LOG_WARN("warning: CloseHandle failed: %s\n", + llama_format_win_err(GetLastError()).c_str()); + } } } #else diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 393c515f97..08502c82a2 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -137,6 +137,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_300B_A47B: return "300B.A47B"; case LLM_TYPE_310B_A15B: return "310B.A15B"; case LLM_TYPE_355B_A32B: return "355B.A32B"; + case LLM_TYPE_744B_A40B: return "744B.A40B"; case LLM_TYPE_E2B: return "E2B"; case LLM_TYPE_E4B: return "E4B"; default: return "?B"; @@ -1826,6 +1827,50 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_GLM_DSA: + { + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_ROPE_DIMENSION_SECTIONS, hparams.rope_sections, 4, false); + + // MoE parameters + ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert); + ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead, false); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); + + // deepseek MLA parameters + ml.get_key(LLM_KV_ATTENTION_Q_LORA_RANK, hparams.n_lora_q); + ml.get_key(LLM_KV_ATTENTION_KV_LORA_RANK, hparams.n_lora_kv); + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH_MLA, hparams.n_embd_head_k_mla_impl, false); + 
ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla_impl, false); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); + + // DSA parameters + ml.get_key(LLM_KV_ATTENTION_INDEXER_HEAD_COUNT, hparams.indexer_n_head); + ml.get_key(LLM_KV_ATTENTION_INDEXER_KEY_LENGTH, hparams.indexer_head_size); + ml.get_key(LLM_KV_ATTENTION_INDEXER_TOP_K, hparams.indexer_top_k); + + // Expert gating function (GLM-4.5 uses sigmoid) + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); + if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { + hparams.expert_gating_func = LLAMA_EXPERT_GATING_FUNC_TYPE_SIGMOID; + } + + // NextN/MTP parameters + ml.get_key(LLM_KV_NEXTN_PREDICT_LAYERS, hparams.nextn_predict_layers, false); + + // TODO: when MTP is implemented, this should probably be updated if needed + hparams.n_layer_kv_from_start = hparams.n_layer - hparams.nextn_predict_layers; + + switch (hparams.n_layer) { + case 79: type = LLM_TYPE_744B_A40B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_BITNET: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -5529,6 +5574,108 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; + case LLM_ARCH_GLM_DSA: + { + const bool is_mla = hparams.is_mla(); + if (!is_mla) { + throw std::runtime_error("GLM_DSA architecture requires MLA"); + } + + // note: these are the actual head sizes you get when treating as MHA or after "decompression" using wv_b for MLA + const int64_t n_embd_head_k_mla = hparams.n_embd_head_k_mla(); + const int64_t n_embd_head_v_mla = hparams.n_embd_head_v_mla(); + + const int64_t n_embd_head_qk_rope = hparams.n_rot; + const int64_t n_embd_head_qk_nope = n_embd_head_k_mla - n_embd_head_qk_rope; + + const int64_t q_lora_rank = hparams.n_lora_q; + const int64_t kv_lora_rank = hparams.n_lora_kv; + + const int64_t n_ff_exp = hparams.n_ff_exp; + const int64_t n_expert_shared = hparams.n_expert_shared; + + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + int flags = 0; + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + // skip all tensors in the NextN layers + // TODO @ngxson : TENSOR_NOT_REQUIRED was a hack, need to remove it later + flags |= TENSOR_SKIP | TENSOR_NOT_REQUIRED; + } + + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, flags); + layer.attn_q_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_A_NORM, "weight", i), {q_lora_rank}, flags); + layer.attn_kv_a_norm = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_NORM, "weight", i), {kv_lora_rank}, flags); + + layer.wq_a = create_tensor(tn(LLM_TENSOR_ATTN_Q_A, "weight", i), {n_embd, q_lora_rank}, flags); + layer.wq_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_B, "weight", i), {q_lora_rank, n_head * n_embd_head_k_mla}, flags); + + layer.wkv_a_mqa = create_tensor(tn(LLM_TENSOR_ATTN_KV_A_MQA, "weight", i), {n_embd, kv_lora_rank + 
n_embd_head_qk_rope}, flags); + + // note: only old legacy GGUF files will have the unsplit wkv_b tensor in + layer.wk_b = create_tensor(tn(LLM_TENSOR_ATTN_K_B, "weight", i), {n_embd_head_qk_nope, kv_lora_rank, n_head}, flags); + layer.wv_b = create_tensor(tn(LLM_TENSOR_ATTN_V_B, "weight", i), {kv_lora_rank, n_embd_head_v_mla, n_head}, flags); + + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_head * n_embd_head_v_mla, n_embd}, flags); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, flags); + + // DSA indexer + layer.indexer_k_norm = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "weight", i), {hparams.indexer_head_size}, flags); + layer.indexer_k_norm_b = create_tensor(tn(LLM_TENSOR_INDEXER_K_NORM, "bias", i), {hparams.indexer_head_size}, flags); + layer.indexer_proj = create_tensor(tn(LLM_TENSOR_INDEXER_PROJ, "weight", i), {n_embd, hparams.indexer_n_head}, flags); + layer.indexer_attn_k = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_K, "weight", i), {n_embd, hparams.indexer_head_size}, flags); + layer.indexer_attn_q_b = create_tensor(tn(LLM_TENSOR_INDEXER_ATTN_Q_B, "weight", i), {q_lora_rank, hparams.indexer_n_head * hparams.indexer_head_size}, flags); + if (i < (int) hparams.n_layer_dense_lead) { + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, flags); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, flags); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, flags); + } else { + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, flags); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); + + if (n_expert == 0) { + throw std::runtime_error("n_expert must be > 0"); + } + if (n_expert_used == 0) { + throw std::runtime_error("n_expert_used must be > 0"); + } + + // MoE branch + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {n_ff_exp, n_embd, n_expert}, flags); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert}, flags); + + // Shared expert branch + layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { n_ff_exp * n_expert_shared, n_embd}, flags); + layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), {n_embd, n_ff_exp * n_expert_shared}, flags); + } + + // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers + if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { + layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); + layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); + layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | 
TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); + } + } + } break; case LLM_ARCH_NEMOTRON: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); @@ -7802,7 +7949,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_scale = %.1f\n", __func__, hparams.expert_weights_scale); } - if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR) { + if (arch == LLM_ARCH_DEEPSEEK2 || arch == LLM_ARCH_DEEPSEEK2OCR || arch == LLM_ARCH_GLM_DSA) { LLAMA_LOG_INFO("%s: n_layer_dense_lead = %d\n", __func__, hparams.n_layer_dense_lead); LLAMA_LOG_INFO("%s: n_lora_q = %d\n", __func__, hparams.n_lora_q); LLAMA_LOG_INFO("%s: n_lora_kv = %d\n", __func__, hparams.n_lora_kv); @@ -8002,7 +8149,6 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, cparams.n_seq_max, nullptr); } else if (llm_arch_is_hybrid(arch)) { - // The main difference between hybrid architectures is the // layer filters, so pick the right one here llama_memory_hybrid::layer_filter_cb filter_attn = nullptr; @@ -8027,7 +8173,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, /* attn_swa_full */ params.swa_full, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_ubatch */ cparams.n_ubatch, /* attn_n_pad */ 1, /* recurrent_type_r */ GGML_TYPE_F32, @@ -8044,7 +8190,7 @@ llama_memory_i * llama_model::create_memory(const llama_memory_params & params, /* attn_type_k */ params.type_k, /* attn_type_v */ params.type_v, /* attn_v_trans */ !cparams.flash_attn, - /* attn_kv_size */ cparams.n_ctx, + /* attn_kv_size */ cparams.n_ctx_seq, /* attn_n_pad */ 1, /* attn_n_swa */ hparams.n_swa, /* attn_swa_type */ hparams.swa_type, @@ -8375,6 +8521,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_DEEPSEEK2: + case LLM_ARCH_GLM_DSA: case LLM_ARCH_DEEPSEEK2OCR: { llm = std::make_unique(*this, params); @@ -8778,6 +8925,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: case LLM_ARCH_MAINCODER: + case LLM_ARCH_GLM_DSA: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index adc8ff6479..b350591429 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -130,6 +130,7 @@ enum llm_type { LLM_TYPE_300B_A47B, // Ernie MoE big LLM_TYPE_310B_A15B, // /MiMo-V2-Flash LLM_TYPE_355B_A32B, // GLM-4.5 + LLM_TYPE_744B_A40B, // GLM-5 LLM_TYPE_E2B, LLM_TYPE_E4B, }; @@ -429,6 +430,13 @@ struct llama_layer { struct ggml_tensor * ssm_g_b = nullptr; struct ggml_tensor * ssm_o_norm = nullptr; + // DSA (deepseek sparse attention) + struct ggml_tensor * indexer_k_norm = nullptr; + struct ggml_tensor * indexer_k_norm_b = nullptr; + struct ggml_tensor * indexer_proj = nullptr; + struct ggml_tensor * indexer_attn_k = nullptr; + struct ggml_tensor * indexer_attn_q_b = nullptr; // note: for lora a/b, not bias + struct llama_layer_posnet posnet; struct llama_layer_convnext convnext; diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 5b1234a1d6..45cb33deb9 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -48,7 +48,8 @@ llm_build_deepseek2::llm_build_deepseek2(const 
llama_model & model, const llm_gr ggml_tensor * inp_out_ids = build_inp_out_ids(); - for (int il = 0; il < n_layer; ++il) { + int effective_n_layers = hparams.n_layer - hparams.nextn_predict_layers; + for (int il = 0; il < effective_n_layers; ++il) { ggml_tensor * inpSA = inpL; // norm @@ -222,7 +223,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); } } - if (il == n_layer - 1 && inp_out_ids) { + if (il == effective_n_layers - 1 && inp_out_ids) { cur = ggml_get_rows(ctx0, cur, inp_out_ids); inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); } diff --git a/src/models/delta-net-base.cpp b/src/models/delta-net-base.cpp new file mode 100644 index 0000000000..0cdf9c324b --- /dev/null +++ b/src/models/delta-net-base.cpp @@ -0,0 +1,333 @@ +#include "models.h" + +#define CHUNK_SIZE 64 + +// utility to get one slice from the third dimension +// input dim: [x, y, c, b] +// output dim: [x, y, 1, b] +static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t c) { + return ggml_view_4d(ctx0, t, t->ne[0], t->ne[1], 1, t->ne[3], + t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); +} + +llm_build_delta_net_base::llm_build_delta_net_base(const llm_graph_params & params) : llm_graph_context(params) {} + +std::pair llm_build_delta_net_base::build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + g = ggml_permute(ctx0, g, 2, 1, 3, 0); // [ 1, n_tokens, H_v, n_seqs] + b = ggml_permute(ctx0, b, 2, 0, 1, 3); // [ 1, n_tokens, H_v, n_seqs] + + const int CS = CHUNK_SIZE; + + const int pad = (CS - n_tokens % CS) % CS; + const int n_chunks = (n_tokens + pad) / CS; + + q = ggml_pad(ctx0, q, 0, pad, 0, 0); + k = ggml_pad(ctx0, k, 0, pad, 0, 0); + v = ggml_pad(ctx0, v, 0, pad, 0, 0); + g = ggml_pad(ctx0, g, 0, pad, 0, 0); + b = ggml_pad(ctx0, b, 0, pad, 0, 0); + + ggml_tensor * v_b = ggml_mul(ctx0, v, b); + ggml_tensor * k_b = ggml_mul(ctx0, k, b); + + cb(v_b, "v_b", il); + cb(k_b, "k_b", il); + + q = ggml_reshape_4d(ctx0, q, S_k, CS, n_chunks, H_k * n_seqs); + k = ggml_reshape_4d(ctx0, k, S_k, CS, n_chunks, H_k * n_seqs); + k_b = ggml_reshape_4d(ctx0, k_b, S_k, CS, n_chunks, H_v * n_seqs); + v = ggml_reshape_4d(ctx0, v, S_v, 
CS, n_chunks, H_v * n_seqs); + v_b = ggml_reshape_4d(ctx0, v_b, S_v, CS, n_chunks, H_v * n_seqs); + + g = ggml_reshape_4d(ctx0, g, CS, 1, n_chunks, H_v * n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, CS, n_chunks, H_v * n_seqs); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_cs = ggml_cumsum(ctx0, g); + cb(g_cs, "g_cs", il); + + ggml_tensor * g_cs_i = g_cs; + ggml_tensor * g_cs_j = ggml_reshape_4d(ctx0, g_cs, 1, CS, n_chunks, H_v * n_seqs); + + g_cs_j = ggml_repeat_4d(ctx0, g_cs_j, CS, CS, n_chunks, H_v * n_seqs); + + // [CS, CS, n_chunks, H_v * n_seqs] + ggml_tensor * decay_mask; + decay_mask = ggml_sub(ctx0, g_cs_j, g_cs_i); + decay_mask = ggml_tri(ctx0, decay_mask, GGML_TRI_TYPE_LOWER_DIAG); + decay_mask = ggml_exp(ctx0, decay_mask); + cb(decay_mask, "decay_mask", il); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kb; + kb = ggml_mul_mat(ctx0, k, k_b); + kb = ggml_mul (ctx0, kb, decay_mask); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * attn; + attn = ggml_tri(ctx0, kb, GGML_TRI_TYPE_LOWER); + + ggml_tensor * identity; + identity = ggml_view_1d(ctx0, attn, CS, 0); + identity = ggml_fill (ctx0, identity, 1.0f); + identity = ggml_diag (ctx0, identity); + + ggml_tensor * lhs = ggml_add(ctx0, attn, identity); + cb(lhs, "dnet_add_ch_lhs", il); + + attn = ggml_neg(ctx0, attn); + + ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); + attn = ggml_add(ctx0, lin_solve, identity); + cb(attn, "dnet_add_ch_attn_solved", il); // [CS, CS, n_chunks, H_k * n_seqs] + + // [S_v, CS, n_chunks, H_v * n_seqs] + v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_b)), attn); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_exp = ggml_exp(ctx0, g_cs); + + k_b = ggml_cont(ctx0, ggml_transpose(ctx0, k_b)); + + // [CS, S_k, n_chunks, H_k * n_seqs] + ggml_tensor * kbg = ggml_mul(ctx0, k_b, g_exp); + cb(kbg, "k_beta_g_exp", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * k_cd = ggml_mul_mat(ctx0, kbg, attn); + cb(k_cd, "k_cumdecay", il); + + // [S_k, CS, n_chunks, H_k * n_seqs] + ggml_tensor * g_exp_t = ggml_transpose(ctx0, g_exp); + ggml_tensor * q_g_exp = ggml_mul(ctx0, q, g_exp_t); + + // [CS, CS, n_chunks, H_k * n_seqs] + ggml_tensor * kq = ggml_mul_mat(ctx0, k, q); + kq = ggml_mul(ctx0, kq, decay_mask); + kq = ggml_tri(ctx0, kq, GGML_TRI_TYPE_LOWER_DIAG); + cb(kq, "kq", il); + + // vectorized calculation of key_gdiff + // improved from the chunked version: + // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) + // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() + // key_gdiff = key * g_diff.unsqueeze(-1) + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + + // get last element in g_cumsum along CS dimension (ne0) + // example: [[x, y, z, ..., last], ...] -> [[last], ...] 
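+    // a minimal reading of the view below: it reuses the strides nb1..nb3 of g_cs and only offsets
+    // the data pointer by (CS - 1) elements, so it picks out the last cumulative-sum entry of each
+    // chunk without copying any data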
+ // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last = ggml_view_4d(ctx0, g_cs, 1, 1, g_cs->ne[2], g_cs->ne[3], + g_cs->nb[1], + g_cs->nb[2], + g_cs->nb[3], + ggml_row_size(g_cs->type, g_cs->ne[0] - 1)); + cb(g_last, "g_last", il); + + // TODO: remove this cont when CUDA supports non-cont unary ops + g_last = ggml_cont(ctx0, g_last); + + // [1, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); + cb(g_last_exp, "g_last_exp", il); + + // [CS, 1, n_chunks, H_v * n_seqs] + ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cs, g_last)); + cb(g_diff, "g_diff", il); + + ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); + ggml_tensor * g_diff_exp_t = ggml_transpose(ctx0, g_diff_exp); + + // [S_k, CS, n_chunks, H_v * n_seqs] + ggml_tensor * kg = ggml_mul(ctx0, k, g_diff_exp_t); + cb(kg, "key_gdiff", il); + + // [CS, S_k, n_chunks, H_v * n_seqs] + ggml_tensor * kg_t = ggml_cont(ctx0, ggml_transpose(ctx0, kg)); + cb(kg_t, "key_gdiff_t", il); + + ggml_tensor * s_t = ggml_transpose(ctx0, s); + s_t = ggml_cont_4d(ctx0, s_t, S_v, S_v, 1, H_v * n_seqs); + cb(s_t, "dnet_add_ch_state", il); + + // [CS, S_v, n_chunks, H_v * n_seqs] + ggml_tensor * v_t = ggml_cont(ctx0, ggml_transpose(ctx0, v)); + + for (int64_t chunk = 0; chunk < n_chunks; chunk++) { + ggml_tensor * ch_k_cd = get_slice_2d(ctx0, k_cd, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_v_t = get_slice_2d(ctx0, v_t, chunk); // [ CS, S_v, 1, H_v * n_seqs] + ggml_tensor * ch_kq = get_slice_2d(ctx0, kq, chunk); // [ CS, CS, 1, H_k * n_seqs] + ggml_tensor * ch_q_g_exp = get_slice_2d(ctx0, q_g_exp, chunk); // [S_k, CS, 1, H_k * n_seqs] + ggml_tensor * ch_kg_t = get_slice_2d(ctx0, kg_t, chunk); // [ CS, S_k, 1, H_v * n_seqs] + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_p = ggml_mul_mat(ctx0, ch_k_cd, s_t); + cb(v_t_p, "v_prime", il); + + // [CS, S_v, 1, H_v * n_seqs] + ggml_tensor * v_t_new = ggml_sub(ctx0, ch_v_t, v_t_p); + cb(v_t_new, "v_t_new", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_t_new, ch_kq); + cb(v_attn, "v_attn", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * attn_inter = ggml_mul_mat(ctx0, s_t, ch_q_g_exp); + cb(attn_inter, "attn_inter", il); + + // [S_v, CS, 1, H_v * n_seqs] + ggml_tensor * o_ch = ggml_add(ctx0, attn_inter, v_attn); + cb(o_ch, "dnet_add_ch_attn_out", il); + + v = ggml_set_inplace(ctx0, v, o_ch, v->nb[1], v->nb[2], v->nb[3], chunk * v->nb[2]); + + // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new + // TODO: head broadcast might not work here - probably will need a transpose + ggml_tensor * kgv = ggml_mul_mat(ctx0, ch_kg_t, v_t_new); // [S_k, S_v, 1, H_k * n_seqs] + + // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew + ggml_tensor * ch_g_last_exp = get_slice_2d(ctx0, g_last_exp, chunk); + s_t = ggml_mul(ctx0, s_t, ch_g_last_exp); + s_t = ggml_add(ctx0, s_t, kgv); + cb(s_t, "dnet_add_ch_state", il); + } + + s_t = ggml_reshape_4d(ctx0, s_t, S_v, S_v, H_v, n_seqs); + + // truncate padded tokens + ggml_tensor * o = ggml_view_4d(ctx0, v, + S_v, n_tokens, H_v, n_seqs, + ggml_row_size(v->type, S_v), + ggml_row_size(v->type, S_v * CS * n_chunks), + ggml_row_size(v->type, S_v * CS * n_chunks * H_v), 0); + + o = ggml_permute (ctx0, o, 0, 2, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} + +std::pair llm_build_delta_net_base::build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * 
g, + ggml_tensor * b, // beta + ggml_tensor * s, // state + int il) { + const int64_t S_k = q->ne[0]; + const int64_t H_k = q->ne[1]; + const int64_t n_tokens = q->ne[2]; + const int64_t n_seqs = q->ne[3]; + + const int64_t S_v = v->ne[0]; + const int64_t H_v = v->ne[1]; + + GGML_ASSERT(n_tokens == 1); + + GGML_ASSERT(S_k == S_v); + GGML_ASSERT(H_v % H_k == 0); + + GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); + GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); + GGML_ASSERT(v->ne[0] == S_v && v->ne[1] == H_v && v->ne[2] == n_tokens && v->ne[3] == n_seqs); + + GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); + GGML_ASSERT(b->ne[0] == H_v && b->ne[2] == n_tokens && b->ne[3] == n_seqs); + GGML_ASSERT(s->ne[0] == S_v && s->ne[1] == S_v && s->ne[2] == H_v && s->ne[3] == n_seqs); + + const float scale = 1.0f / sqrtf(S_k); + + q = ggml_scale(ctx0, q, scale); + + q = ggml_permute(ctx0, q, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + k = ggml_permute(ctx0, k, 0, 2, 1, 3); // [S_k, n_tokens, H_k, n_seqs] + v = ggml_permute(ctx0, v, 0, 2, 1, 3); // [S_v, n_tokens, H_v, n_seqs] + + cb(q, "q_in", il); + cb(k, "k_in", il); + cb(v, "v_in", il); + cb(b, "b_in", il); + cb(g, "g_in", il); + + g = ggml_reshape_4d(ctx0, g, 1, 1, H_v, n_seqs); + b = ggml_reshape_4d(ctx0, b, 1, 1, H_v, n_seqs); + + // [S_v, S_v, H_v, n_seqs] + g = ggml_exp(ctx0, g); + s = ggml_mul(ctx0, s, g); + + ggml_tensor * s_t = ggml_cont(ctx0, ggml_transpose(ctx0, s)); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * sk; + sk = ggml_mul (ctx0, s_t, k); + sk = ggml_sum_rows(ctx0, sk); + + // [S_v, 1, H_v, n_seqs] + ggml_tensor * d; + d = ggml_sub(ctx0, v, ggml_transpose(ctx0, sk)); + d = ggml_mul(ctx0, d, b); + + // [1, S_v, H_v, n_seqs] + ggml_tensor * d_t; + d_t = ggml_transpose(ctx0, d); + + // [S_v, S_v, H_v, n_seqs] + ggml_tensor * kd; + k = ggml_repeat(ctx0, k, s); + kd = ggml_mul (ctx0, k, d_t); + + s_t = ggml_add(ctx0, s_t, kd); + + cb(s_t, "dnet_add_ar_state", il); + + ggml_tensor * s_q = ggml_mul (ctx0, s_t, q); + ggml_tensor * o = ggml_sum_rows(ctx0, s_q); + + o = ggml_permute (ctx0, o, 2, 0, 1, 3); // [S_v, H_v, n_tokens, n_seqs] + s = ggml_transpose(ctx0, s_t); // [S_v, S_v, H_v, n_seqs] + + return {o, s}; +} diff --git a/src/models/falcon-h1.cpp b/src/models/falcon-h1.cpp index b641a09407..785a7e5e66 100644 --- a/src/models/falcon-h1.cpp +++ b/src/models/falcon-h1.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_falcon_h1::llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/granite-hybrid.cpp b/src/models/granite-hybrid.cpp index f6ca4c17a2..726ecdcca7 100644 --- a/src/models/granite-hybrid.cpp +++ b/src/models/granite-hybrid.cpp @@ -2,7 +2,7 @@ llm_build_granite_hybrid::llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/jamba.cpp b/src/models/jamba.cpp index a0187772cc..ceab581740 100644 --- a/src/models/jamba.cpp +++ b/src/models/jamba.cpp @@ -1,6 +1,6 @@ #include "models.h" -llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : 
llm_graph_context_mamba(params) { +llm_build_jamba::llm_build_jamba(const llama_model & model, const llm_graph_params & params) : llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; ggml_tensor * cur; diff --git a/src/models/kimi-linear.cpp b/src/models/kimi-linear.cpp index 0f037d1a39..133834021d 100644 --- a/src/models/kimi-linear.cpp +++ b/src/models/kimi-linear.cpp @@ -1,6 +1,8 @@ #include "models.h" #include "ggml.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 // Causal Conv1d function for Q,K,V @@ -41,8 +43,11 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t conv_x->nb[1], conv_x->nb[2], n_seq_tokens * conv_x->nb[0]); ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_x, - ggml_view_1d(ctx0, conv_states_all, conv_state_size * n_seqs, - (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); + ggml_view_3d(ctx0, conv_states_all, + d_conv - 1, d_inner, n_seqs, + (d_conv - 1) * ggml_element_size(conv_states_all), // nb1: contiguous within one channel's conv taps + n_embd_r_total * ggml_element_size(conv_states_all), // nb2: stride between sequences (skip over K,V states) + (kv_head * n_embd_r_total + qkv * conv_state_size) * ggml_element_size(conv_states_all)))); // offset to first seq's Q/K/V state // Reshape conv weight: GGUF [d_conv, 1, d_inner, 1] -> ggml_ssm_conv expects [d_conv, d_inner] // GGUF stores as [d_conv, 1, d_inner, 1] with memory layout w[conv_step + channel * d_conv] // vLLM stores as [d_inner, d_conv] with memory layout w[channel * d_conv + conv_step] @@ -62,7 +67,7 @@ static ggml_tensor * causal_conv1d(ggml_cgraph * gf, ggml_context * ctx0, ggml_t } llm_build_kimi_linear::llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_mamba_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/graph-context-mamba.cpp b/src/models/mamba-base.cpp similarity index 97% rename from src/models/graph-context-mamba.cpp rename to src/models/mamba-base.cpp index b9a363b32b..aaac9487df 100644 --- a/src/models/graph-context-mamba.cpp +++ b/src/models/mamba-base.cpp @@ -1,8 +1,10 @@ #include "models.h" -llm_graph_context_mamba::llm_graph_context_mamba(const llm_graph_params & params) : llm_graph_context(params) {} +#include "llama-memory-recurrent.h" -ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * inp, +llm_build_mamba_base::llm_build_mamba_base(const llm_graph_params & params) : llm_graph_context(params) {} + +ggml_tensor * llm_build_mamba_base::build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, @@ -143,7 +145,7 @@ ggml_tensor * llm_graph_context_mamba::build_mamba_layer(llm_graph_input_rs * in return cur; } -ggml_tensor * llm_graph_context_mamba::build_mamba2_layer(llm_graph_input_rs * inp, +ggml_tensor * llm_build_mamba_base::build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, diff --git a/src/models/mamba.cpp b/src/models/mamba.cpp index 46819613c2..55fd2e055c 100644 --- a/src/models/mamba.cpp +++ b/src/models/mamba.cpp @@ -1,7 +1,6 @@ #include "models.h" - -llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & params) : llm_graph_context_mamba(params) { +llm_build_mamba::llm_build_mamba(const llama_model & model, const llm_graph_params & 
params) : llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/models.h b/src/models/models.h index 3c66d32531..920a8e5798 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -1,23 +1,51 @@ #pragma once -#include "../llama-model.h" -#include "../llama-graph.h" +#include "llama-model.h" +#include "llama-graph.h" -// TODO: remove in follow-up PR - move to .cpp files -#include "../llama-memory-recurrent.h" +// note: almost all graphs require atleast sqrtf, so include cmath globally #include -struct llm_graph_context_mamba : public llm_graph_context { - llm_graph_context_mamba(const llm_graph_params & params); +// +// base classes +// - virtual ~llm_graph_context_mamba() = default; +struct llm_build_mamba_base : public llm_graph_context { + llm_build_mamba_base(const llm_graph_params & params); + + virtual ~llm_build_mamba_base() = default; ggml_tensor * build_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); ggml_tensor * build_mamba2_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il) const; }; -// Base class for RWKV-related models +struct llm_build_delta_net_base : public llm_graph_context { + llm_build_delta_net_base(const llm_graph_params & params); + + virtual ~llm_build_delta_net_base() = default; + + // returns pair of output and new state + std::pair build_delta_net_chunking( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); + + // returns pair of output and new state + std::pair build_delta_net_autoregressive( + ggml_tensor * q, + ggml_tensor * k, + ggml_tensor * v, + ggml_tensor * g, + ggml_tensor * b, + ggml_tensor * s, + int il); +}; + struct llm_build_rwkv6_base : public llm_graph_context { const llama_model & model; @@ -58,6 +86,10 @@ struct llm_build_rwkv7_base : public llm_graph_context { int il) const; }; +// +// models +// + struct llm_build_afmoe : public llm_graph_context { llm_build_afmoe(const llama_model & model, const llm_graph_params & params); }; @@ -175,7 +207,7 @@ struct llm_build_falcon : public llm_graph_context { llm_build_falcon(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_falcon_h1 : public llm_graph_context_mamba { +struct llm_build_falcon_h1 : public llm_build_mamba_base { llm_build_falcon_h1(const llama_model & model, const llm_graph_params & params); }; @@ -253,7 +285,7 @@ private: const int il); }; -struct llm_build_granite_hybrid : public llm_graph_context_mamba { +struct llm_build_granite_hybrid : public llm_build_mamba_base { llm_build_granite_hybrid(const llama_model & model, const llm_graph_params & params); ggml_tensor * build_layer_ffn(ggml_tensor * cur, ggml_tensor * inpSA, const llama_model & model, const int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, ggml_tensor * inp_pos, llm_graph_input_attn_kv * inp_attn, @@ -284,11 +316,12 @@ struct llm_build_jais : public llm_graph_context { llm_build_jais(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_jamba : public llm_graph_context_mamba { +struct llm_build_jamba : public llm_build_mamba_base { llm_build_jamba(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_kimi_linear : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_kimi_linear : public llm_build_mamba_base { 
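+    // note: it currently declares its own KDA graph helpers below instead of reusing the shared
+    // delta-net pair, hence the TODO above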
llm_build_kimi_linear(const llama_model & model, const llm_graph_params & params); std::pair build_kda_autoregressive( @@ -347,7 +380,7 @@ struct llm_build_maincoder : public llm_graph_context { llm_build_maincoder(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_mamba : public llm_graph_context_mamba { +struct llm_build_mamba : public llm_build_mamba_base { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -379,11 +412,11 @@ struct llm_build_nemotron : public llm_graph_context { llm_build_nemotron(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_nemotron_h : public llm_graph_context_mamba { +struct llm_build_nemotron_h : public llm_build_mamba_base { llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params); - ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il); + ggml_tensor * build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il); ggml_tensor * build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, - const llama_model & model, const int64_t n_embd_head, const int il); + const llama_model & model, int64_t n_embd_head, int il); }; struct llm_build_neo_bert : public llm_graph_context { @@ -428,7 +461,7 @@ struct llm_build_phi3 : public llm_graph_context { llm_build_phi3(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_plamo2 : public llm_graph_context_mamba { +struct llm_build_plamo2 : public llm_build_mamba_base { llm_build_plamo2(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_plamo2_mamba_layer(llm_graph_input_rs * inp, ggml_tensor * cur, const llama_model & model, const llama_ubatch & ubatch, int il); @@ -477,7 +510,7 @@ struct llm_build_qwen3vlmoe : public llm_graph_context { llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params); }; -struct llm_build_qwen3next : public llm_graph_context_mamba { +struct llm_build_qwen3next : public llm_build_delta_net_base { llm_build_qwen3next(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -489,38 +522,12 @@ private: ggml_tensor * build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il); ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); - // returns pair of output and new state - std::pair build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il); - - // returns pair of output and new state - std::pair build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il); - ggml_tensor * build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -535,7 +542,8 @@ private: const llama_model & model; }; -struct llm_build_qwen35 : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_qwen35 : public llm_graph_context { llm_build_qwen35(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( @@ -553,6 +561,7 @@ private: ggml_tensor * diag_mask, int il); + ggml_tensor * build_layer_ffn( ggml_tensor * cur, int il); @@ -594,7 +603,8 
@@ private: const llama_model & model; }; -struct llm_build_qwen35moe : public llm_graph_context_mamba { +// TODO: derive llm_build_delta_net_base instead +struct llm_build_qwen35moe : public llm_graph_context { llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params); private: ggml_tensor * build_layer_attn( diff --git a/src/models/nemotron-h.cpp b/src/models/nemotron-h.cpp index 079c730ac2..d61d62a8c9 100644 --- a/src/models/nemotron-h.cpp +++ b/src/models/nemotron-h.cpp @@ -1,9 +1,7 @@ #include "models.h" - - llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); @@ -65,8 +63,8 @@ llm_build_nemotron_h::llm_build_nemotron_h(const llama_model & model, const llm_ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * cur, llm_graph_input_attn_kv * inp_attn, const llama_model & model, - const int64_t n_embd_head, - const int il) { + int64_t n_embd_head, + int il) { // compute Q and K ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); cb(Qcur, "Qcur", il); @@ -106,7 +104,7 @@ ggml_tensor * llm_build_nemotron_h::build_attention_layer(ggml_tensor * return cur; } -ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, const int il) { +ggml_tensor * llm_build_nemotron_h::build_ffn_layer(ggml_tensor * cur, const llama_model & model, int il) { if (model.layers[il].ffn_gate_inp == nullptr) { cur = build_ffn(cur, model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, diff --git a/src/models/plamo2.cpp b/src/models/plamo2.cpp index 31115a08f9..3af236843b 100644 --- a/src/models/plamo2.cpp +++ b/src/models/plamo2.cpp @@ -1,7 +1,9 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_plamo2::llm_build_plamo2(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params) { + llm_build_mamba_base(params) { ggml_tensor * cur; ggml_tensor * inpL; diff --git a/src/models/qwen35.cpp b/src/models/qwen35.cpp index 592c170457..94c68dbb26 100644 --- a/src/models/qwen35.cpp +++ b/src/models/qwen35.cpp @@ -1,10 +1,11 @@ -#include "ggml.h" #include "models.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 llm_build_qwen35::llm_build_qwen35(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_graph_context(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/qwen35moe.cpp b/src/models/qwen35moe.cpp index 0db8f825c6..93da7ea628 100644 --- a/src/models/qwen35moe.cpp +++ b/src/models/qwen35moe.cpp @@ -1,10 +1,11 @@ -#include "ggml.h" #include "models.h" +#include "llama-memory-recurrent.h" + #define CHUNK_SIZE 64 llm_build_qwen35moe::llm_build_qwen35moe(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_graph_context(params), model(model) { const int64_t n_embd_head = hparams.n_embd_head_v; GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); diff --git a/src/models/qwen3next.cpp b/src/models/qwen3next.cpp index 99b1a76a48..0fdf2d42c2 100644 --- a/src/models/qwen3next.cpp +++ b/src/models/qwen3next.cpp @@ -1,10 +1,9 @@ -#include "ggml.h" #include "models.h" -#define CHUNK_SIZE 64 +#include "llama-memory-recurrent.h" 
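For orientation, the recurrence that the new llm_build_delta_net_base helpers own (and that the removed qwen3next implementation further down in this diff expresses as a ggml graph) reduces, for one token and one head, to a few lines of arithmetic on the d_k x d_v state. The sketch below is illustrative only, not part of the patch: the function name and the flat std::vector layout are invented here, and the upstream preprocessing (L2-normalizing q and k, scaling q, sigmoid on beta, exp on the gate) is assumed to have already been applied.

#include <cstddef>
#include <vector>

// One decoding step of the gated delta rule for a single head, on plain float
// buffers. S is the d_k x d_v recurrent state stored row-major (S[i*d_v + j]),
// q and k have length d_k, v has length d_v, decay = exp(g) and beta = sigmoid(b)
// are per-head scalars for this token. Returns the output row and updates S in place.
static std::vector<float> gated_delta_rule_step(
        std::vector<float> & S,
        const std::vector<float> & q,
        const std::vector<float> & k,
        const std::vector<float> & v,
        float decay, float beta, size_t d_k, size_t d_v) {
    // last_recurrent_state = last_recurrent_state * g_t
    for (float & s : S) {
        s *= decay;
    }
    // kv_mem = (state * k.unsqueeze(-1)).sum(dim=-2)  ->  kv_mem[j] = sum_i S[i][j]*k[i]
    std::vector<float> kv_mem(d_v, 0.0f);
    for (size_t i = 0; i < d_k; ++i) {
        for (size_t j = 0; j < d_v; ++j) {
            kv_mem[j] += S[i*d_v + j]*k[i];
        }
    }
    // delta = (v_t - kv_mem) * beta_t; state += k.unsqueeze(-1) * delta (outer product)
    for (size_t i = 0; i < d_k; ++i) {
        for (size_t j = 0; j < d_v; ++j) {
            S[i*d_v + j] += k[i]*(v[j] - kv_mem[j])*beta;
        }
    }
    // core_attn_out = (state * q.unsqueeze(-1)).sum(dim=-2)
    std::vector<float> out(d_v, 0.0f);
    for (size_t i = 0; i < d_k; ++i) {
        for (size_t j = 0; j < d_v; ++j) {
            out[j] += S[i*d_v + j]*q[i];
        }
    }
    return out;
}

The chunked variant removed below batches this same recurrence over CHUNK_SIZE-token blocks, replacing the per-token loop with the decay-mask and triangular-solve construction visible in build_delta_net_chunking.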
llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_graph_params & params) : - llm_graph_context_mamba(params), model(model) { + llm_build_delta_net_base(params), model(model) { ggml_tensor * cur; ggml_tensor * inpL; @@ -16,17 +15,6 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr ggml_tensor * inp_pos = build_inp_pos(); ggml_tensor * inp_out_ids = build_inp_out_ids(); - ggml_tensor * causal_mask = - ggml_tri(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, CHUNK_SIZE, CHUNK_SIZE), 1.0f), - GGML_TRI_TYPE_LOWER); - - ggml_tensor * identity = ggml_diag(ctx0, ggml_fill_inplace(ctx0, ggml_new_tensor_1d(ctx0, GGML_TYPE_F32, CHUNK_SIZE), 1.0f)); - ggml_tensor * diag_mask = ggml_add(ctx0, causal_mask, identity); - - ggml_build_forward_expand(gf, causal_mask); - ggml_build_forward_expand(gf, identity); - ggml_build_forward_expand(gf, diag_mask); - for (int il = 0; il < n_layer; ++il) { ggml_tensor * inpSA = inpL; @@ -36,7 +24,7 @@ llm_build_qwen3next::llm_build_qwen3next(const llama_model & model, const llm_gr // Determine layer type and build appropriate attention mechanism if (hparams.is_recurrent(il)) { // Linear attention layer (gated delta net) - cur = build_layer_attn_linear(inp->get_recr(), cur, causal_mask, identity, diag_mask, il); + cur = build_layer_attn_linear(inp->get_recr(), cur, il); } else { // Full attention layer cur = build_layer_attn(inp->get_attn(), cur, inp_pos, il); @@ -94,354 +82,6 @@ static ggml_tensor * get_slice_2d(ggml_context * ctx0, ggml_tensor * t, int64_t t->nb[1], t->nb[2], t->nb[3], t->nb[2] * c); } -std::pair llm_build_qwen3next::build_delta_net_chunking( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - q = ggml_cont_4d(ctx0, ggml_permute(ctx0, q, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - k = ggml_cont_4d(ctx0, ggml_permute(ctx0, k, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - v = ggml_cont_4d(ctx0, ggml_permute(ctx0, v, 0, 2, 1, 3), S_v, n_tokens, H_v, n_seqs); - g = ggml_cont_4d(ctx0, ggml_permute(ctx0, g, 2, 0, 3, 1), n_tokens, 1, H_k, n_seqs); - - beta = ggml_cont(ctx0, ggml_permute(ctx0, beta, 2, 0, 1, 3)); - state = ggml_reshape_4d(ctx0, 
state, S_v, S_v, H_v, n_seqs); - - cb(q, "q_perm", il); - cb(k, "k_perm", il); - cb(v, "v_perm", il); - cb(beta, "beta_perm", il); - cb(g, "g_perm", il); - cb(state, "state_in", il); - - GGML_ASSERT(q->ne[1] == n_tokens && q->ne[0] == S_k && q->ne[2] == H_k && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[1] == n_tokens && k->ne[0] == S_k && k->ne[2] == H_k && k->ne[3] == n_seqs); - GGML_ASSERT(v->ne[1] == n_tokens && v->ne[0] == S_v && v->ne[2] == H_k && v->ne[3] == n_seqs); - GGML_ASSERT(beta->ne[1] == n_tokens && beta->ne[2] == H_k && beta->ne[0] == 1 && beta->ne[3] == n_seqs); - - // Do padding - const int64_t chunk_size = CHUNK_SIZE; - - const int64_t pad = (chunk_size - n_tokens % chunk_size) % chunk_size; - const int64_t n_chunks = (n_tokens + pad) / chunk_size; - - q = ggml_pad(ctx0, q, 0, pad, 0, 0); - k = ggml_pad(ctx0, k, 0, pad, 0, 0); - v = ggml_pad(ctx0, v, 0, pad, 0, 0); - g = ggml_pad(ctx0, g, pad, 0, 0, 0); - beta = ggml_pad(ctx0, beta, 0, pad, 0, 0); - - cb(q, "q_pad", il); - cb(k, "k_pad", il); - cb(v, "v_pad", il); - cb(beta, "beta_pad", il); - cb(g, "g_pad", il); - - ggml_tensor * v_beta = ggml_mul(ctx0, v, beta); - ggml_tensor * k_beta = ggml_mul(ctx0, k, beta); - - cb(v_beta, "v_beta", il); - cb(k_beta, "k_beta", il); - - q = ggml_reshape_4d(ctx0, q, S_k, chunk_size, n_chunks, H_k * n_seqs); - k = ggml_reshape_4d(ctx0, k, S_k, chunk_size, n_chunks, H_k * n_seqs); - k_beta = ggml_reshape_4d(ctx0, k_beta, S_k, chunk_size, n_chunks, H_k * n_seqs); - v = ggml_reshape_4d(ctx0, v, S_v, chunk_size, n_chunks, H_v * n_seqs); - v_beta = ggml_reshape_4d(ctx0, v_beta, S_v, chunk_size, n_chunks, H_v * n_seqs); - - g = ggml_reshape_4d(ctx0, g, chunk_size, 1, n_chunks, H_k * n_seqs); - beta = ggml_reshape_4d(ctx0, beta, 1, chunk_size, n_chunks, H_k * n_seqs); - - ggml_tensor * g_cumsum = ggml_cumsum(ctx0, g); - cb(g_cumsum, "g_cumsum", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * gcs_i = g_cumsum; // ggml_reshape_4d(ctx0, g_cumsum, chunk_size, 1, n_chunks, H_v * n_seqs); - ggml_tensor * gcs_j = ggml_reshape_4d(ctx0, g_cumsum, 1, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * gcs_j_broadcast = - ggml_repeat_4d(ctx0, gcs_j, chunk_size, chunk_size, n_chunks, H_v * n_seqs); - - ggml_tensor * decay_mask = ggml_sub(ctx0, gcs_j_broadcast, gcs_i); - cb(decay_mask, "decay_mask", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - decay_mask = ggml_exp(ctx0, decay_mask); - decay_mask = ggml_mul(ctx0, decay_mask, diag_mask); - - ggml_tensor * kmulkbeta = ggml_mul_mat(ctx0, k, k_beta); - - ggml_tensor * k_decay = ggml_mul(ctx0, kmulkbeta, decay_mask); - ggml_tensor * attn = ggml_neg(ctx0, ggml_mul(ctx0, k_decay, causal_mask)); - cb(attn, "attn_pre_solve", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_lower = ggml_mul(ctx0, attn, causal_mask); - ggml_tensor * lhs = ggml_sub(ctx0, ggml_repeat(ctx0, identity, attn_lower), attn_lower); - - ggml_tensor * lin_solve = ggml_solve_tri(ctx0, lhs, attn, true, true, false); - attn = ggml_mul(ctx0, lin_solve, causal_mask); - attn = ggml_add(ctx0, attn, identity); - cb(attn, "attn_solved", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - v = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, v_beta)), attn); - - ggml_tensor * g_cumsum_t = ggml_cont(ctx0, ggml_transpose(ctx0, g_cumsum)); - ggml_tensor * gexp = ggml_exp(ctx0, g_cumsum_t); - - ggml_tensor * kbeta_gexp = ggml_mul(ctx0, k_beta, 
gexp); - cb(kbeta_gexp, "kbeta_gexp", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * k_cumdecay = - ggml_cont(ctx0, ggml_transpose(ctx0, ggml_mul_mat(ctx0, attn, ggml_cont(ctx0, ggml_transpose(ctx0, kbeta_gexp))))); - cb(k_cumdecay, "k_cumdecay", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * attn_kq = ggml_mul_mat(ctx0, k, q); - attn_kq = ggml_mul(ctx0, attn_kq, decay_mask); - attn_kq = ggml_mul(ctx0, attn_kq, diag_mask); - cb(attn_kq, "attn_kq", il); // shape: (chunk_size, chunk_size, n_chunks, H_v * n_seqs) - - - // vectorized calculation of key_gdiff - // improved from the chunked version: - // g_last = torch.clamp(g_cum[:, :, -1], max=50.0).exp().unsqueeze(-1).unsqueeze(-1) - // g_diff = torch.clamp(g_cum[:, :, -1:] - g_cum, max=50.0).exp() - // key_gdiff = key * g_diff.unsqueeze(-1) - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - - // get last element in g_cumsum along chunk_size dimension (ne0) - // example: [[x, y, z, ..., last], ...] -> [[last], ...] - ggml_tensor * g_last = ggml_view_4d(ctx0, g_cumsum, 1, 1, g_cumsum->ne[2], g_cumsum->ne[3], - g_cumsum->nb[1], g_cumsum->nb[2], g_cumsum->nb[3], - (g_cumsum->ne[0] - 1) * ggml_element_size(g_cumsum)); - g_last = ggml_cont(ctx0, g_last); - cb(g_last, "g_last", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_last_exp = ggml_exp(ctx0, g_last); - cb(g_last_exp, "g_last_exp", il); // shape: (1, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff = ggml_neg(ctx0, ggml_sub(ctx0, g_cumsum, g_last)); - cb(g_diff, "g_diff", il); // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - - ggml_tensor * g_diff_exp = ggml_exp(ctx0, g_diff); - ggml_tensor * g_diff_exp_t = ggml_reshape_4d(ctx0, g_diff_exp, - 1, chunk_size, n_chunks, g_diff_exp->ne[3]); - - ggml_tensor * key_gdiff = ggml_mul(ctx0, k, g_diff_exp_t); - cb(key_gdiff, "key_gdiff", il); // shape: (S_k, chunk_size, n_chunks, H_v * n_seqs) - - ggml_tensor * key_gdiff_t = ggml_cont(ctx0, ggml_transpose(ctx0, key_gdiff)); - cb(key_gdiff_t, "key_gdiff_t", il); // shape: (chunk_size, S_k, n_chunks, H_v * n_seqs) - - - // state to be updated per chunk - ggml_tensor * new_state = state; // ggml_dup(ctx0, state); - cb(new_state, "new_state", il); // shape: (S_v, S_v, H_v, n_seqs) - - // shape after loop of chunks: (S_v, chunk_size, n_chunks, H_v * n_seqs) - ggml_tensor * core_attn_out = nullptr; - - for (int64_t chunk = 0; chunk < n_chunks; chunk++) { - // shape: (S_k, chunk_size, 1, H_k * n_seqs) - ggml_tensor * q_chunk = get_slice_2d(ctx0, q, chunk); // (no cont), next op: ggml_mul - - // shape: (S_v, chunk_size, 1, H_v * n_seqs) - ggml_tensor * v_chunk = get_slice_2d(ctx0, v, chunk); // (no cont), next op: ggml_repeat - - // shape: (chunk_size, 1, n_chunks, H_v * n_seqs) - ggml_tensor * gexp_chunk = get_slice_2d(ctx0, gexp, chunk); // (no cont), next op: ggml_mul - - // shape: (chunk_size, 1, H_v * n_seqs) - ggml_tensor * k_cumdecay_chunk = get_slice_2d(ctx0, k_cumdecay, chunk); // (no cont), next op: ggml_mul_mat - - // attn = (q_i @ k_i.transpose(-1, -2) * decay_mask[:, :, i]).masked_fill_(mask, 0) - // replaced by precomputed attn_kq - ggml_tensor * attn_chunk = get_slice_2d(ctx0, attn_kq, chunk); - cb(attn_chunk, "attn_chunk", il); - - ggml_tensor * state_t = ggml_cont_4d(ctx0, ggml_permute(ctx0, new_state, 1, 0, 2, 3), S_v, S_v, 1, H_v * n_seqs); - - // v_prime = (k_cumdecay[:, :, i]) @ last_recurrent_state - ggml_tensor * v_prime = 
ggml_mul_mat(ctx0, state_t, k_cumdecay_chunk); - cb(v_prime, "v_prime_chunk", il); // shape: (S_v, 1, H_v * n_seqs) - - // v_new = v_i - v_prime - ggml_tensor * v_new = ggml_sub(ctx0, ggml_repeat(ctx0, v_chunk, v_prime), v_prime); - ggml_tensor * v_new_t = ggml_cont(ctx0, ggml_transpose(ctx0, v_new)); - cb(v_new, "v_new_chunk", il); - - // attn_inter = (q_i * g[:, :, i, :, None].exp()) @ last_recurrent_state - ggml_tensor * q_g_exp = ggml_mul(ctx0, q_chunk, gexp_chunk); - ggml_tensor * attn_inter = ggml_mul_mat(ctx0, state_t, q_g_exp); - cb(attn_inter, "attn_inter_chunk", il); - - // core_attn_out[:, :, i] = attn_inter + attn @ v_new - ggml_tensor * v_attn = ggml_mul_mat(ctx0, v_new_t, attn_chunk); - cb(v_attn, "v_attn_chunk", il); - - ggml_tensor * core_attn_out_chunk = ggml_add(ctx0, attn_inter, v_attn); - cb(core_attn_out_chunk, "core_attn_out_chunk", il); // shape: (S_v, chunk_size, 1, H_v * n_seqs) - - core_attn_out = core_attn_out == nullptr - ? core_attn_out_chunk - : ggml_concat(ctx0, core_attn_out, core_attn_out_chunk, 2); - - // kgdmulvnew = (key_gdiff).transpose(-1, -2) @ v_new - ggml_tensor * k_gdiff_t = get_slice_2d(ctx0, key_gdiff_t, chunk); - //ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, k_gdiff, v_new); // this is slower on metal, why? - ggml_tensor * kgdmulvnew = ggml_mul_mat(ctx0, v_new_t, k_gdiff_t); - - // last_recurrent_state = last_recurrent_state * g_last + kgdmulvnew - ggml_tensor * gexp_last_chunk = ggml_cont(ctx0, get_slice_2d(ctx0, g_last_exp, chunk)); - new_state = ggml_add(ctx0, - ggml_mul(ctx0, new_state, ggml_reshape_4d(ctx0, gexp_last_chunk, gexp_last_chunk->ne[0], gexp_last_chunk->ne[1], H_v, n_seqs)), - ggml_reshape_4d(ctx0, kgdmulvnew, kgdmulvnew->ne[0], kgdmulvnew->ne[1], H_v, n_seqs)); - } - - // truncate padded tokens - ggml_tensor * output_tokens = ggml_view_4d(ctx0, core_attn_out, - S_v, n_tokens, H_v, n_seqs, - ggml_row_size(core_attn_out->type, S_v), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks), - ggml_row_size(core_attn_out->type, S_v * chunk_size * n_chunks * H_v), 0); - output_tokens = ggml_cont(ctx0, output_tokens); - cb(output_tokens, "output_tokens", il); - - // permute back to (S_v, H_v, n_tokens, n_seqs) - output_tokens = ggml_permute(ctx0, output_tokens, 0, 2, 1, 3); - output_tokens = ggml_cont(ctx0, output_tokens); - - return {output_tokens, new_state}; -} - -std::pair llm_build_qwen3next::build_delta_net_autoregressive( - ggml_tensor * q, - ggml_tensor * k, - ggml_tensor * v, - ggml_tensor * g, - ggml_tensor * beta, - ggml_tensor * state, - int il) { - const int64_t S_k = q->ne[0]; - const int64_t H_k = q->ne[1]; - const int64_t n_tokens = q->ne[2]; - const int64_t n_seqs = q->ne[3]; - - const int64_t S_v = v->ne[0]; - const int64_t H_v = v->ne[1]; - - GGML_ASSERT(n_tokens == 1); // This function is optimized for single token processing - GGML_ASSERT(v->ne[2] == n_tokens); - GGML_ASSERT(k->ne[2] == n_tokens); - GGML_ASSERT(g->ne[0] == H_v && g->ne[1] == n_tokens && g->ne[2] == n_seqs); - GGML_ASSERT(beta->ne[0] == H_v && beta->ne[2] == n_tokens && beta->ne[3] == n_seqs); - GGML_ASSERT(state->ne[0] == S_v && state->ne[1] == S_v * H_v && state->ne[2] == 1 && state->ne[3] == n_seqs); - - GGML_ASSERT(q->ne[0] == S_k && q->ne[1] == H_k && q->ne[2] == n_tokens && q->ne[3] == n_seqs); - GGML_ASSERT(k->ne[0] == S_k && k->ne[1] == H_k && k->ne[2] == n_tokens && k->ne[3] == n_seqs); - - GGML_ASSERT(H_k == H_v); // we did a repeat to make sure this is the case - - const float eps_norm = hparams.f_norm_rms_eps; - - q = 
ggml_l2_norm(ctx0, q, eps_norm); - k = ggml_l2_norm(ctx0, k, eps_norm); - - const float scale = 1.0f / sqrtf(S_v); - - q = ggml_scale(ctx0, q, scale); - beta = ggml_sigmoid(ctx0, beta); - - cb(q, "q_in", il); - cb(k, "k_in", il); - cb(v, "v_in", il); - cb(beta, "beta_in", il); - cb(g, "g_in", il); - - state = ggml_reshape_4d(ctx0, state, S_v, S_v, H_v, n_seqs); - - ggml_tensor * g_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, g), 1, 1, H_k, n_seqs); - ggml_tensor * beta_t = ggml_reshape_4d(ctx0, ggml_transpose(ctx0, beta), 1, 1, H_k, n_seqs); - - // Apply exponential to g_t - g_t = ggml_exp(ctx0, g_t); - - // Apply the gated delta rule for the single timestep - // last_recurrent_state = last_recurrent_state * g_t - state = ggml_mul(ctx0, state, g_t); - - // kv_mem = (last_recurrent_state * k_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * k_t_unsqueezed = ggml_reshape_4d(ctx0, k, 1, S_v, H_v, n_seqs); - ggml_tensor * kv_mem = ggml_mul(ctx0, state, k_t_unsqueezed); - // we need to sum over dim=-2, so we transpose, sum, then transpose again - kv_mem = ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, kv_mem)))); - - // v_t = v.unsqueeze(2) (we insert the singleton dimension after n_seqs and H_v) - ggml_tensor * v_t = ggml_reshape_4d(ctx0, v, S_v, 1, H_v, n_seqs); - // delta = (v_t - kv_mem) * beta_t - ggml_tensor * v_diff = ggml_sub(ctx0, v_t, kv_mem); // both should be [S_v, 1, H_v, n_seqs] - ggml_tensor * delta = ggml_mul(ctx0, v_diff, beta_t); - - // last_recurrent_state = last_recurrent_state + k_t.unsqueeze(-1) * delta - ggml_tensor * k_t_delta = ggml_mul(ctx0, ggml_repeat_4d(ctx0, k_t_unsqueezed, S_v, S_v, H_v, n_seqs), delta); - state = ggml_add(ctx0, state, k_t_delta); - - // Compute the attention output - // core_attn_out = (last_recurrent_state * q_t.unsqueeze(-1)).sum(dim=-2) - ggml_tensor * q_t_unsqueezed = ggml_reshape_4d(ctx0, q, 1, S_v, H_v, n_seqs); // unsqueeze q_t - ggml_tensor * state_q = ggml_mul(ctx0, state, q_t_unsqueezed); - // again, since it's over dim = -2, transpose, sum, transpose back - ggml_tensor * core_attn_out = - ggml_transpose(ctx0, ggml_sum_rows(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, state_q)))); - - // core_attn_out should be [S_v, 1, H_v, n_seqs] after this - cb(core_attn_out, "output_tokens", il); - cb(state, "new_state", il); - - return {core_attn_out, state}; -} - ggml_tensor * llm_build_qwen3next::build_norm_gated( ggml_tensor * input, ggml_tensor * weights, @@ -472,39 +112,29 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( // Split Q projection into query and gate // The split should be along dimension 0 (the feature dimension) ggml_tensor * Qcur = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, - Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], 0); + cb(Qcur, "Qcur_view", il); + ggml_tensor * gate = ggml_view_4d(ctx0, Qcur_full, n_embd_head, n_head, n_tokens, 1, Qcur_full->nb[1], Qcur_full->nb[2], Qcur_full->nb[3], n_embd_head * ggml_element_size(Qcur_full)); - cb(Qcur, "Qcur", il); cb(gate, "gate", il); - // Now reshape Qcur to [n_embd_head, n_head, n_tokens] for multi-head attention - Qcur = ggml_cont_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); - cb(Qcur, "Qcur_reshaped", il); - - // Apply Q normalization - Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); - cb(Qcur, "Qcur_normed", il); - ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); cb(Kcur, "Kcur", il); ggml_tensor * Vcur = 
build_lora_mm(model.layers[il].wv, cur); cb(Vcur, "Vcur", il); - // Apply K normalization Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, nullptr, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, nullptr, LLM_NORM_RMS, il); cb(Kcur, "Kcur_normed", il); - // Reshape gate to [n_embd, n_tokens] for the sigmoid gating (flatten the heads) - gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cb(gate, "gate_reshaped", il); - - Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - - // Apply RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, @@ -519,7 +149,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( cb(Kcur, "Kcur", il); cb(Vcur, "Vcur", il); - // Attention computation const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; cur = build_attn(inp, @@ -527,10 +156,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn( Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); cb(cur, "attn_pregate", il); - ggml_tensor * gate_sigmoid = ggml_sigmoid(ctx0, gate); - cb(gate_sigmoid, "gate_sigmoid", il); + // TODO: CUDA is missing non-contiguous unary ops. when implemented: remove this cont + gate = ggml_cont_2d(ctx0, gate, n_embd_head * n_head, n_tokens); - cur = ggml_mul(ctx0, cur, gate_sigmoid); + gate = ggml_sigmoid(ctx0, gate); + cb(gate, "gate_sigmoid", il); + + gate = ggml_reshape_2d(ctx0, gate, n_embd_head * n_head, n_tokens); + + cur = ggml_mul(ctx0, cur, gate); cb(cur, "attn_gated", il); cur = build_lora_mm(model.layers[il].wo, cur); @@ -560,7 +194,6 @@ std::pair llm_build_qwen3next::build_qkvz( cb(z, "z", il); return { qkv_mixed, z }; - } else { // legacy (slower) path ggml_tensor * mixed_qkvz = build_lora_mm(model.layers[il].ssm_in, input); @@ -624,9 +257,6 @@ std::pair llm_build_qwen3next::build_qkvz( ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( llm_graph_input_rs * inp, ggml_tensor * cur, - ggml_tensor * causal_mask, - ggml_tensor * identity, - ggml_tensor * diag_mask, int il) { const auto * mctx_cur = inp->mctx; @@ -671,7 +301,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( split_sizes_ba[0] * ggml_element_size(mixed_ba_reshaped)); cb(a, "a", il); - ggml_tensor * beta = ggml_cont_4d(ctx0, b, num_v_heads, 1, n_seq_tokens, n_seqs); + // TODO: CUDA is missing non-contiguous unary ops. 
when implemented: remove this cont + b = ggml_cont(ctx0, b); + + ggml_tensor * beta = ggml_sigmoid(ctx0, b); + + beta = ggml_reshape_4d(ctx0, beta, num_v_heads, 1, n_seq_tokens, n_seqs); // Reshape a to merge head dimensions: [batch, seq_len, num_k_heads, num_v_heads/num_k_heads] -> [batch, seq_len, num_v_heads] ggml_tensor * alpha = ggml_cont_3d(ctx0, a, num_v_heads, n_seq_tokens, n_seqs); @@ -679,6 +314,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * alpha_biased = ggml_add(ctx0, alpha, model.layers[il].ssm_dt); ggml_tensor * alpha_softplus = ggml_softplus(ctx0, alpha_biased); cb(alpha_softplus, "a_softplus", il); + ggml_tensor * gate = ggml_mul(ctx0, alpha_softplus, model.layers[il].ssm_a); // -A_log.exp() * softplus cb(gate, "gate", il); @@ -686,8 +322,6 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_states_all = mctx_cur->get_r_l(il); ggml_tensor * ssm_states_all = mctx_cur->get_s_l(il); - // bool use_precomputed_states = n_seq_tokens == 1 && mctx_cur->has_previous_state(); - // Build the convolution states tensor ggml_tensor * conv_states = build_rs(inp, conv_states_all, hparams.n_embd_r(), n_seqs); cb(conv_states, "conv_states", il); @@ -696,11 +330,12 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_tensor * conv_kernel = model.layers[il].ssm_conv1d; const int64_t conv_kernel_size = conv_kernel->ne[0]; const int64_t conv_channels = d_inner + 2 * hparams.ssm_n_group * hparams.ssm_d_state; - conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); + + conv_states = ggml_reshape_3d(ctx0, conv_states, conv_kernel_size - 1, conv_channels, n_seqs); cb(conv_states, "conv_states_reshaped", il); - qkv_mixed = ggml_permute(ctx0, qkv_mixed, 1, 0, 2, 3); - cb(qkv_mixed, "qkv_mixed_permuted", il); + qkv_mixed = ggml_transpose(ctx0, qkv_mixed); + cb(qkv_mixed, "qkv_mixed_transposed", il); ggml_tensor * conv_input = ggml_concat(ctx0, conv_states, qkv_mixed, 0); cb(conv_input, "conv_input", il); @@ -720,7 +355,10 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( ggml_build_forward_expand(gf, ggml_cpy(ctx0, last_conv_states, state_update_target)); cb(conv_states_all, "conv_states_updated", il); - // Apply SSM convolution + ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); + state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim, num_v_heads, n_seqs); + cb(state, "state_predelta", il); + ggml_tensor * conv_output_proper = ggml_ssm_conv(ctx0, conv_input, conv_kernel); cb(conv_output_proper, "conv_output_raw", il); @@ -734,26 +372,36 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( int64_t nb1_qkv = ggml_row_size(conv_qkv_mix->type, qkv_dim); // Extract the convolved Q, K, V from conv_output - ggml_tensor * q_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, 0); + ggml_tensor * q_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + 0); + + ggml_tensor * k_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_k_dim, num_k_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_k_dim), + nb1_qkv, + nb1_qkv * n_seq_tokens, + head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); + + ggml_tensor * v_conv = ggml_view_4d(ctx0, conv_qkv_mix, head_v_dim, num_v_heads, n_seq_tokens, n_seqs, + ggml_row_size(conv_qkv_mix->type, head_v_dim), + nb1_qkv, 
+ nb1_qkv * n_seq_tokens, + ggml_row_size(conv_qkv_mix->type, 2 * head_k_dim * num_k_heads)); + cb(q_conv, "q_conv", il); - ggml_tensor * k_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_k_dim * num_k_heads, n_seq_tokens * n_seqs, nb1_qkv, - head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(k_conv, "k_conv", il); - ggml_tensor * v_conv = - ggml_view_2d(ctx0, conv_qkv_mix, head_v_dim * num_v_heads, n_seq_tokens * n_seqs, nb1_qkv, - 2 * head_k_dim * num_k_heads * ggml_element_size(conv_qkv_mix)); cb(v_conv, "v_conv", il); - // Unsqueeze them - q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); - v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); + const float eps_norm = hparams.f_norm_rms_eps; - ggml_tensor * state = build_rs(inp, ssm_states_all, hparams.n_embd_s(), n_seqs); - state = ggml_reshape_4d(ctx0, state, head_v_dim, head_v_dim * num_v_heads, 1, n_seqs); - cb(state, "state_predelta", il); + q_conv = ggml_l2_norm(ctx0, q_conv, eps_norm); + k_conv = ggml_l2_norm(ctx0, k_conv, eps_norm); + + //q_conv = ggml_cont_4d(ctx0, q_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //k_conv = ggml_cont_4d(ctx0, k_conv, head_k_dim, num_k_heads, n_seq_tokens, n_seqs); + //v_conv = ggml_cont_4d(ctx0, v_conv, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // if head keys and value keys are different, repeat to force tensors into matching shapes if (num_k_heads != num_v_heads) { @@ -786,7 +434,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( if (n_seq_tokens == 1) { attn_out = build_delta_net_autoregressive(q_conv, k_conv, v_conv, gate, beta, state, il); } else { - attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, causal_mask, identity, diag_mask, il); + attn_out = build_delta_net_chunking(q_conv, k_conv, v_conv, gate, beta, state, il); } ggml_tensor * output = attn_out.first; ggml_tensor * new_state = attn_out.second; @@ -795,19 +443,15 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( // Update the recurrent states ggml_build_forward_expand(gf, - ggml_cpy(ctx0, new_state, - ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, - kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); - - // Reshape both attn_out_final and z to 2D tensors for normalization - // attn_out_final: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * attn_out_2d_final = ggml_reshape_2d(ctx0, output, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_cpy(ctx0, new_state, + ggml_view_1d(ctx0, ssm_states_all, hparams.n_embd_s() * n_seqs, + kv_head * hparams.n_embd_s() * ggml_element_size(ssm_states_all)))); // z: [head_dim, n_heads, n_tokens, n_seqs] -> [n_heads * n_tokens * n_seqs, head_dim] - ggml_tensor * z_2d = ggml_reshape_2d(ctx0, z, head_v_dim, num_v_heads * n_seq_tokens * n_seqs); + ggml_tensor * z_2d = ggml_reshape_4d(ctx0, z, head_v_dim, num_v_heads, n_seq_tokens, n_seqs); // Apply gated normalization: self.norm(core_attn_out, z) - ggml_tensor * attn_out_norm = build_norm_gated(attn_out_2d_final, model.layers[il].ssm_norm, z_2d, il); + ggml_tensor * attn_out_norm = build_norm_gated(output, model.layers[il].ssm_norm, z_2d, il); // Final reshape: [head_dim, n_heads, n_tokens, n_seqs] -> [n_tokens, n_seqs, n_heads * head_dim] ggml_tensor * final_output = ggml_reshape_3d(ctx0, attn_out_norm, head_v_dim * num_v_heads, 
n_seq_tokens, n_seqs); @@ -818,7 +462,8 @@ ggml_tensor * llm_build_qwen3next::build_layer_attn_linear( cb(cur, "linear_attn_out", il); // Reshape back to original dimensions - cur = ggml_cont_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + cur = ggml_reshape_2d(ctx0, cur, n_embd, n_seq_tokens * n_seqs); + return cur; } @@ -839,7 +484,7 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int if (model.layers[il].ffn_up_shexp != nullptr) { ggml_tensor * ffn_shexp = build_ffn(cur, - model.layers[il].ffn_up_shexp, NULL, NULL, + model.layers[il].ffn_up_shexp, NULL, NULL, model.layers[il].ffn_gate_shexp, NULL, NULL, model.layers[il].ffn_down_shexp, NULL, NULL, NULL, @@ -852,11 +497,9 @@ ggml_tensor * llm_build_qwen3next::build_layer_ffn(ggml_tensor * cur, const int ggml_tensor * shared_gate = build_lora_mm(model.layers[il].ffn_gate_inp_shexp, cur); cb(shared_gate, "shared_expert_gate", il); - // Apply sigmoid to the gate shared_gate = ggml_sigmoid(ctx0, shared_gate); cb(shared_gate, "shared_expert_gate_sigmoid", il); - // Apply the gate to the shared expert output ffn_shexp = ggml_mul(ctx0, ffn_shexp, shared_gate); cb(ffn_shexp, "ffn_shexp_gated", il); diff --git a/src/models/rwkv6-base.cpp b/src/models/rwkv6-base.cpp index 7beed2daff..83aeab7280 100644 --- a/src/models/rwkv6-base.cpp +++ b/src/models/rwkv6-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv6_base::llm_build_rwkv6_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/models/rwkv7-base.cpp b/src/models/rwkv7-base.cpp index cda4465384..7fcab77745 100644 --- a/src/models/rwkv7-base.cpp +++ b/src/models/rwkv7-base.cpp @@ -1,5 +1,7 @@ #include "models.h" +#include "llama-memory-recurrent.h" + llm_build_rwkv7_base::llm_build_rwkv7_base(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model) {} diff --git a/src/unicode.cpp b/src/unicode.cpp index adfc489d1f..b88d953bd2 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -1,16 +1,10 @@ -#if defined(_MSC_VER) -#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING -#endif - #include "unicode.h" #include "unicode-data.h" #include #include -#include #include #include -#include #include #include #include @@ -199,27 +193,6 @@ static std::unordered_map unicode_utf8_to_byte_map() { return map; } -static inline std::wstring unicode_wstring_from_utf8(const std::string & s) { -#if defined(__clang__) - // disable C++17 deprecation warning for std::codecvt_utf8 -# pragma clang diagnostic push -# pragma clang diagnostic ignored "-Wdeprecated-declarations" -#elif defined(__GNUC__) -# pragma GCC diagnostic push -# pragma GCC diagnostic ignored "-Wdeprecated-declarations" -#endif - - std::wstring_convert> conv; - -#if defined(__clang__) -# pragma clang diagnostic pop -#elif defined(__GNUC__) -# pragma GCC diagnostic pop -#endif - - return conv.from_bytes(s); -} - static std::vector unicode_byte_encoding_process(const std::vector & bpe_words) { std::vector bpe_encoded_words; for (const auto & word : bpe_words) { @@ -1028,10 +1001,10 @@ std::vector unicode_regex_split(const std::string & text, const std break; } } + const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); if (use_collapsed) { // sanity-check that the original regex does not contain any non-ASCII characters - const auto cpts_regex = unicode_cpts_from_utf8(regex_expr); for (size_t i = 0; i < cpts_regex.size(); ++i) { if (cpts_regex[i] >= 128) 
{ throw std::runtime_error("Regex includes both unicode categories and non-ASCII characters - not supported"); @@ -1087,7 +1060,7 @@ std::vector unicode_regex_split(const std::string & text, const std bpe_offsets = unicode_regex_split_stl(text_collapsed, regex_expr_collapsed, bpe_offsets); } else { // no unicode category used, we can use std::wregex directly - const std::wstring wregex_expr = unicode_wstring_from_utf8(regex_expr); + std::wstring wregex_expr(cpts_regex.begin(), cpts_regex.end()); // std::wregex \s does not mach non-ASCII whitespaces, using 0x0B as fallback std::wstring wtext(cpts.begin(), cpts.end()); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 5d5e44a0c7..746648a064 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -1943,7 +1943,11 @@ struct test_unary : public test_case { ggml_tensor * a; if (v & 1) { - auto ne = ne_a; ne[0] *= 3; + auto ne = ne_a; + ne[0] *= 3; + ne[1] *= 2; + ne[2] *= 5; + ne[3] *= 4; a = ggml_new_tensor(ctx, type, 4, ne.data()); if (grad_supported) { ggml_set_param(a); @@ -2782,9 +2786,10 @@ struct test_set : public test_case { const ggml_type type_dst; const std::array ne; const int dim; + const bool inplace; std::string vars() override { - return VARS_TO_STR4(type_src, type_dst, ne, dim); + return VARS_TO_STR5(type_src, type_dst, ne, dim, inplace); } size_t op_size(ggml_tensor * t) override { @@ -2792,8 +2797,8 @@ struct test_set : public test_case { } test_set(ggml_type type_src = GGML_TYPE_F32, ggml_type type_dst = GGML_TYPE_F32, - std::array ne = {6, 5, 4, 3}, int dim = 1) - : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim) {} + std::array ne = {6, 5, 4, 3}, int dim = 1, bool inplace = false) + : type_src(type_src), type_dst(type_dst), ne(ne), dim(dim), inplace(inplace) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * src = ggml_new_tensor(ctx, type_src, 4, ne.data()); @@ -2804,7 +2809,7 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { ne_dst[i] *= 2; } - ggml_tensor* dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); + ggml_tensor * dst = ggml_new_tensor(ctx, type_dst, 4, ne_dst.data()); ggml_set_param(dst); ggml_set_name(dst, "dst"); @@ -2812,9 +2817,16 @@ struct test_set : public test_case { for (int i = 0; i < dim; ++i) { offset += ((ne_dst[i] - ne[i])/2)*dst->nb[i]; } - ggml_tensor * out = ggml_set(ctx, dst, src, - // The backward pass requires setting a contiguous region: - src->nb[1], src->nb[2], src->nb[3], offset); + ggml_tensor * out; + if (inplace) { + out = ggml_set_inplace(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } else { + out = ggml_set(ctx, dst, src, + // The backward pass requires setting a contiguous region: + src->nb[1], src->nb[2], src->nb[3], offset); + } ggml_set_name(out, "out"); return out; @@ -5809,20 +5821,27 @@ struct test_l2_norm : public test_case { const ggml_type type; const std::array ne; const float eps; + bool v; std::string vars() override { - return VARS_TO_STR2(type, ne); + return VARS_TO_STR4(type, ne, eps, v); } test_l2_norm(ggml_type type = GGML_TYPE_F32, std::array ne = {64, 64, 320, 1}, - float eps = 1e-12f) - : type(type), ne(ne), eps(eps) {} + float eps = 1e-12f, + bool v = false) + : type(type), ne(ne), eps(eps), v(v) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne.data()); ggml_set_name(a, "a"); + if (v) { + a = ggml_view_4d(ctx, a, 
a->ne[0]/2, a->ne[1]/2, a->ne[2]/2, a->ne[3]/2, a->nb[1], a->nb[2], a->nb[3], 0); + ggml_set_name(a, "view of a"); + } + ggml_tensor * out = ggml_l2_norm(ctx, a, eps); ggml_set_name(out, "out"); @@ -5835,26 +5854,46 @@ struct test_acc : public test_case { const ggml_type type; const std::array ne_a; const std::array ne_b; + const int64_t stride_dim; std::string vars() override { - return VARS_TO_STR3(type, ne_a, ne_b); + return VARS_TO_STR4(type, ne_a, ne_b, stride_dim); } test_acc(ggml_type type = GGML_TYPE_F32, - std::array ne_a = {256, 17, 1, 1}, - std::array ne_b = {256, 16, 1, 1}) - : type(type), ne_a(ne_a), ne_b(ne_b) {} + std::array ne_a = {256, 17, 2, 3}, + std::array ne_b = {256, 16, 2, 3}, + uint64_t stride_dim = -1) + : type(type), ne_a(ne_a), ne_b(ne_b), stride_dim(stride_dim) {} ggml_tensor * build_graph(ggml_context * ctx) override { ggml_tensor * a = ggml_new_tensor(ctx, type, 4, ne_a.data()); ggml_set_param(a); ggml_set_name(a, "a"); - ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne_b.data()); - ggml_set_param(b); + ggml_tensor * b; + if (stride_dim == 1 || stride_dim == 2 || stride_dim == 3) { + // Create a larger tensor and take a view at a non-zero offset. + // This tests that the backend correctly handles b's data offset + std::array ne_b_pad = {ne_b[0], ne_b[1], ne_b[2], ne_b[3]}; + ne_b_pad[stride_dim] += 1; + ggml_tensor * b_pad = ggml_new_tensor(ctx, type, 4, ne_b_pad.data()); + ggml_set_param(b_pad); + ggml_set_name(b_pad, "b_pad"); + // View that skips the first row, so b has a non-zero byte offset + b = ggml_view_4d(ctx, b_pad, + ne_b[0], ne_b[1], ne_b[2], ne_b[3], + b_pad->nb[1], b_pad->nb[2], b_pad->nb[3], + b_pad->nb[1]); + } else { + b = ggml_new_tensor(ctx, type, 4, ne_b.data()); + ggml_set_param(b); + } ggml_set_name(b, "b"); - ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], b->nb[1]); + // When ne_b[0] < ne_a[0], a->nb[1] != b->nb[1], so the stride + // parameters to ggml_acc don't match b's natural stride. 
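The comment above is the crux of the new strided test_acc cases: in the ggml layout convention a contiguous tensor has nb[0] equal to its type size and nb[i] = nb[i-1]*ne[i-1], so shrinking ne_b[0] below ne_a[0] guarantees a->nb[1] != b->nb[1]. A rough reference for the accumulate semantics being exercised, assuming F32 data and a contiguous a, with the helper names invented for illustration:

#include <array>
#include <cstddef>
#include <cstdint>
#include <vector>

// Byte strides of a contiguous 4-D tensor in the ggml layout convention.
static std::array<size_t, 4> contiguous_nb(const std::array<int64_t, 4> & ne, size_t type_size) {
    std::array<size_t, 4> nb;
    nb[0] = type_size;
    for (int i = 1; i < 4; ++i) {
        nb[i] = nb[i - 1]*(size_t) ne[i - 1];
    }
    return nb; // e.g. F32 {256,17,2,3} gives nb[1] = 1024, while F32 {128,16,2,3} gives nb[1] = 512
}

// dst starts as a copy of a; every element of b is then added into the view of dst
// described by a's own strides plus a byte offset, while b itself is read contiguously,
// mirroring the call ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0) in the test below.
static std::vector<float> acc_ref(
        const std::vector<float> & a, const std::array<int64_t, 4> & ne_a,
        const std::vector<float> & b, const std::array<int64_t, 4> & ne_b,
        size_t offset_bytes) {
    const std::array<size_t, 4> nb_a = contiguous_nb(ne_a, sizeof(float));
    std::vector<float> dst = a;
    for (int64_t i3 = 0; i3 < ne_b[3]; ++i3) {
        for (int64_t i2 = 0; i2 < ne_b[2]; ++i2) {
            for (int64_t i1 = 0; i1 < ne_b[1]; ++i1) {
                for (int64_t i0 = 0; i0 < ne_b[0]; ++i0) {
                    const size_t dst_byte = offset_bytes + i0*nb_a[0] + i1*nb_a[1] + i2*nb_a[2] + i3*nb_a[3];
                    const size_t src_idx  = ((i3*ne_b[2] + i2)*ne_b[1] + i1)*ne_b[0] + i0;
                    dst[dst_byte/sizeof(float)] += b[src_idx];
                }
            }
        }
    }
    return dst;
}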
+ ggml_tensor * out = ggml_acc(ctx, a, b, a->nb[1], a->nb[2], a->nb[3], 0); ggml_set_name(out, "out"); return out; @@ -7424,11 +7463,13 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_dup(GGML_TYPE_I16, {10, 8, 3, 1}, {1, 2, 0, 3})); for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_F32, GGML_TYPE_F32, {6, 5, 4, 3}, dim, true)); } for (int dim = 1; dim < GGML_MAX_DIMS; ++dim) { - test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, false)); + test_cases.emplace_back(new test_set(GGML_TYPE_I32, GGML_TYPE_I32, {6, 5, 4, 3}, dim, true)); } // same-type copy @@ -7562,7 +7603,8 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, v, eps)); } test_cases.emplace_back(new test_rms_norm_back(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); - test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, false)); + test_cases.emplace_back(new test_l2_norm(GGML_TYPE_F32, { n, 5, 4, 3 }, eps, true)); } } @@ -8128,29 +8170,40 @@ static std::vector> make_test_cases_eval() { } test_cases.emplace_back(new test_sum()); - test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 2, 1, 3})); // row-contiguous but non-contiguous test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 3, 2, 1})); test_cases.emplace_back(new test_sum(GGML_TYPE_F32, {11, 5, 6, 3}, {0, 1, 3, 2})); + test_cases.emplace_back(new test_mean()); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 1, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32, 256, 1, 1 })); + test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32768, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); + test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous + test_cases.emplace_back(new test_sum_rows()); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, false)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, false, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 11, 5, 6, 3 }, true, true)); - test_cases.emplace_back(new test_mean()); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1, 1, 1 })); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, false)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, false, true)); + test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 16, 5, 6, 3 }, true, true)); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new 
test_mean(GGML_TYPE_F32, { 33, 1, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 1024, 1, 1 })); test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 1024, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_sum(GGML_TYPE_F32, { 33, 256, 1, 1 }, { 1, 0, 2, 3 })); // sum dst not-contiguous test_cases.emplace_back(new test_sum_rows(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 33, 256, 1, 1 })); - test_cases.emplace_back(new test_mean(GGML_TYPE_F32, { 32769, 1, 1, 1 })); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm(GGML_TYPE_F32, {9, 9, 1280, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {64, 64, 320, 1})); test_cases.emplace_back(new test_group_norm_mul_add(GGML_TYPE_F32, {9, 9, 1280, 1})); - test_cases.emplace_back(new test_acc()); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3)); test_cases.emplace_back(new test_pad()); test_cases.emplace_back(new test_pad(GGML_TYPE_F32, {33, 17, 2, 1}, 4, 3, true)); // circular test_cases.emplace_back(new test_pad_ext()); @@ -8248,7 +8301,7 @@ static std::vector> make_test_cases_eval() { //for (int kv : { 1, 17, 31, 33, 61, 113, 65, 127, 129, 130, 255, 260, 371, 380, 407, 512, 1024, }) { for (int kv : { 113, 512, 1024, }) { if (nr2 != 1 && kv != 512) continue; - for (int nb : { 1, 3, 32, 35, }) { + for (int nb : { 1, 3, 32, 75, }) { for (ggml_prec prec : {GGML_PREC_F32, GGML_PREC_DEFAULT}) { if (hsk != 128 && prec == GGML_PREC_DEFAULT) continue; for (ggml_type type_KV : {GGML_TYPE_F32, GGML_TYPE_F16, GGML_TYPE_BF16, GGML_TYPE_Q8_0, GGML_TYPE_Q4_0}) { @@ -8585,6 +8638,14 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate + // acc + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 1, 1}, {256, 16, 1, 1}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, -1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {256, 16, 2, 3}, 1)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {128, 16, 2, 3}, 2)); + test_cases.emplace_back(new test_acc(GGML_TYPE_F32, {256, 17, 2, 3}, {64, 16, 2, 3}, 3)); + return test_cases; } diff --git a/tools/cli/cli.cpp b/tools/cli/cli.cpp index 02ccb72598..ad421e6326 100644 --- a/tools/cli/cli.cpp +++ b/tools/cli/cli.cpp @@ -52,6 +52,7 @@ struct cli_context { json messages = json::array(); std::vector input_files; task_params defaults; + bool verbose_prompt; // thread for showing "loading" animation std::atomic loading_show; @@ -66,6 +67,8 @@ struct cli_context { defaults.stream = true; // make sure we always use 
streaming mode defaults.timings_per_token = true; // in order to get timings even when we cancel mid-way // defaults.return_progress = true; // TODO: show progress + + verbose_prompt = params.verbose_prompt; } std::string generate_completion(result_timings & out_timings) { @@ -91,6 +94,12 @@ struct cli_context { rd.post_task({std::move(task)}); } + if (verbose_prompt) { + console::set_display(DISPLAY_TYPE_PROMPT); + console::log("%s\n\n", chat_params.prompt.c_str()); + console::set_display(DISPLAY_TYPE_RESET); + } + // wait for first result console::spinner::start(); server_task_result_ptr result = rd.next(should_stop); diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 14c7acbdf9..262f4fca1f 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -20,6 +20,7 @@ add_library(mtmd models/internvl.cpp models/kimivl.cpp models/kimik25.cpp + models/nemotron-v2-vl.cpp models/llama4.cpp models/llava.cpp models/minicpmv.cpp diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 74f7092b57..b1931f42b3 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -254,6 +254,7 @@ enum projector_type { PROJECTOR_TYPE_GLM4V, PROJECTOR_TYPE_YOUTUVL, PROJECTOR_TYPE_KIMIK25, + PROJECTOR_TYPE_NEMOTRON_V2_VL, PROJECTOR_TYPE_UNKNOWN, }; @@ -289,6 +290,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_GLM4V, "glm4v"}, { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, { PROJECTOR_TYPE_KIMIK25, "kimik25"}, + { PROJECTOR_TYPE_NEMOTRON_V2_VL, "nemotron_v2_vl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index ea1b9ce785..b17292eefb 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -15,6 +15,7 @@ enum ffn_op_type { FFN_GELU_ERF, FFN_SILU, FFN_GELU_QUICK, + FFN_RELU_SQR, }; enum norm_type { diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index e90ef35331..c552b40538 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -559,6 +559,12 @@ ggml_tensor * clip_graph::build_ffn( cur = ggml_gelu_quick(ctx0, cur); cb(cur, "ffn_gelu_quick", il); } break; + case FFN_RELU_SQR: + { + cur = ggml_relu(ctx0, cur); + cur = ggml_sqr(ctx0, cur); + cb(cur, "ffn_relu_sqr", il); + } break; } if (down) { @@ -807,6 +813,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_NEMOTRON_V2_VL: + { + builder = std::make_unique(ctx, img); + } break; case PROJECTOR_TYPE_LLAMA4: { builder = std::make_unique(ctx, img); @@ -1111,6 +1121,7 @@ struct clip_model_loader { } } break; case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: { get_u32(KEY_PROJ_SCALE_FACTOR, hparams.n_merge, false); } break; @@ -1779,6 +1790,12 @@ struct clip_model_loader { model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); model.mm_3_b = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "bias")); } break; + case PROJECTOR_TYPE_NEMOTRON_V2_VL: + { + model.mm_0_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 0, "weight")); + model.mm_1_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 1, "weight")); + model.mm_3_w = get_tensor(string_format(TN_MVLM_PROJ_MLP, 3, "weight")); + } break; case PROJECTOR_TYPE_GLMA: { model.conv1d_1_w = get_tensor(string_format(TN_CONV1D, 1, "weight")); @@ -3445,6 +3462,7 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str case PROJECTOR_TYPE_GLM_EDGE: case PROJECTOR_TYPE_GEMMA3: case 
PROJECTOR_TYPE_INTERNVL: // TODO @ngxson : support dynamic resolution + case PROJECTOR_TYPE_NEMOTRON_V2_VL: { clip_image_u8 resized_image; int sz = params.image_size; @@ -3837,6 +3855,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_GEMMA3: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_LLAMA4: { // both X and Y are downscaled by the scale factor @@ -4281,6 +4300,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima case PROJECTOR_TYPE_GEMMA3NV: case PROJECTOR_TYPE_IDEFICS3: case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: case PROJECTOR_TYPE_QWEN2A: case PROJECTOR_TYPE_GLMA: case PROJECTOR_TYPE_ULTRAVOX: @@ -4444,6 +4464,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_MUSIC_FLAMINGO: return ctx->model.mm_2_w->ne[1]; case PROJECTOR_TYPE_INTERNVL: + case PROJECTOR_TYPE_NEMOTRON_V2_VL: return ctx->model.mm_3_w->ne[1]; case PROJECTOR_TYPE_LLAMA4: return ctx->model.mm_model_proj->ne[1]; diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index 6182161cfd..98c2db1b6f 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -42,6 +42,11 @@ struct clip_graph_internvl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_nemotron_v2_vl : clip_graph { + clip_graph_nemotron_v2_vl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_llama4 : clip_graph { clip_graph_llama4(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/nemotron-v2-vl.cpp b/tools/mtmd/models/nemotron-v2-vl.cpp new file mode 100644 index 0000000000..03094be1b2 --- /dev/null +++ b/tools/mtmd/models/nemotron-v2-vl.cpp @@ -0,0 +1,35 @@ +#include "models.h" + +ggml_cgraph * clip_graph_nemotron_v2_vl::build() { + GGML_ASSERT(model.class_embedding != nullptr); + GGML_ASSERT(model.position_embeddings != nullptr); + + const int n_registers = model.class_embedding->ne[1]; + const int n_pos = n_patches + n_registers; + + ggml_tensor * inp = build_inp(); + + // add position embeddings (pre-downsampled during GGUF conversion for fixed 512x512 input) + inp = ggml_add(ctx0, inp, model.position_embeddings); + cb(inp, "inp_pos", -1); + + inp = ggml_concat(ctx0, model.class_embedding, inp, 1); + + ggml_tensor * cur = build_vit(inp, n_pos, NORM_TYPE_NORMAL, hparams.ffn_op, nullptr, nullptr); + + cur = ggml_view_2d(ctx0, cur, + n_embd, n_patches, + ggml_row_size(cur->type, n_embd), + n_registers * ggml_row_size(cur->type, n_embd)); + + cur = build_patch_merge_permute(cur, model.hparams.n_merge); + + { + cur = build_norm(cur, model.mm_0_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cur = build_ffn(cur, model.mm_1_w, nullptr, nullptr, nullptr, model.mm_3_w, nullptr, FFN_RELU_SQR, -1); + } + + ggml_build_forward_expand(gf, cur); + + return gf; +} diff --git a/tools/rpc/rpc-server.cpp b/tools/rpc/rpc-server.cpp index 521f79622d..6feb0e91f3 100644 --- a/tools/rpc/rpc-server.cpp +++ b/tools/rpc/rpc-server.cpp @@ -132,7 +132,8 @@ static std::string fs_get_cache_directory() { if (getenv("LLAMA_CACHE")) { cache_directory = std::getenv("LLAMA_CACHE"); } else { -#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || defined(__OpenBSD__) +#if defined(__linux__) || defined(__FreeBSD__) || defined(_AIX) || \ + defined(__OpenBSD__) || 
defined(__NetBSD__) if (std::getenv("XDG_CACHE_HOME")) { cache_directory = std::getenv("XDG_CACHE_HOME"); } else if (std::getenv("HOME")) { diff --git a/tools/server/CMakeLists.txt b/tools/server/CMakeLists.txt index a39b4c5b35..8c8ec18831 100644 --- a/tools/server/CMakeLists.txt +++ b/tools/server/CMakeLists.txt @@ -28,10 +28,6 @@ target_link_libraries(${TARGET} PUBLIC common mtmd ${CMAKE_THREAD_LIBS_INIT}) set(TARGET llama-server) -if (NOT LLAMA_HTTPLIB) - message(FATAL_ERROR "LLAMA_HTTPLIB is OFF, cannot build llama-server. Hint: to skip building server, set -DLLAMA_BUILD_SERVER=OFF") -endif() - set(TARGET_SRCS server.cpp server-http.cpp diff --git a/tools/server/README.md b/tools/server/README.md index d132830171..0b56ca1e27 100644 --- a/tools/server/README.md +++ b/tools/server/README.md @@ -19,7 +19,7 @@ Set of LLM REST APIs and a web UI to interact with llama.cpp. * Speculative decoding * Easy-to-use web UI -For the ful list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) +For the full list of features, please refer to [server's changelog](https://github.com/ggml-org/llama.cpp/issues/9291) ## Usage diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index e3b06f4901..75fc856f54 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/webui/.storybook/main.ts b/tools/server/webui/.storybook/main.ts index bfd16fa224..4f6945f210 100644 --- a/tools/server/webui/.storybook/main.ts +++ b/tools/server/webui/.storybook/main.ts @@ -1,17 +1,24 @@ import type { StorybookConfig } from '@storybook/sveltekit'; +import { dirname, resolve } from 'path'; +import { fileURLToPath } from 'url'; + +const __dirname = dirname(fileURLToPath(import.meta.url)); const config: StorybookConfig = { stories: ['../tests/stories/**/*.mdx', '../tests/stories/**/*.stories.@(js|ts|svelte)'], addons: [ '@storybook/addon-svelte-csf', '@chromatic-com/storybook', - '@storybook/addon-docs', + '@storybook/addon-vitest', '@storybook/addon-a11y', - '@storybook/addon-vitest' + '@storybook/addon-docs' ], - framework: { - name: '@storybook/sveltekit', - options: {} + framework: '@storybook/sveltekit', + viteFinal: async (config) => { + config.server = config.server || {}; + config.server.fs = config.server.fs || {}; + config.server.fs.allow = [...(config.server.fs.allow || []), resolve(__dirname, '../tests')]; + return config; } }; export default config; diff --git a/tools/server/webui/.storybook/preview.ts b/tools/server/webui/.storybook/preview.ts index 8d530e43e3..566dbfd289 100644 --- a/tools/server/webui/.storybook/preview.ts +++ b/tools/server/webui/.storybook/preview.ts @@ -13,7 +13,7 @@ const preview: Preview = { }, backgrounds: { - disable: true + disabled: true }, a11y: { diff --git a/tools/server/webui/docs/flows/settings-flow.md b/tools/server/webui/docs/flows/settings-flow.md index 578e01e6e1..40ad3bd94d 100644 --- a/tools/server/webui/docs/flows/settings-flow.md +++ b/tools/server/webui/docs/flows/settings-flow.md @@ -49,14 +49,20 @@ sequenceDiagram settingsStore->>serverStore: defaultParams serverStore-->>settingsStore: {temperature, top_p, top_k, ...} - settingsStore->>ParamSvc: extractServerDefaults(defaultParams) - ParamSvc-->>settingsStore: Record + loop each SYNCABLE_PARAMETER + alt key NOT in userOverrides + settingsStore->>settingsStore: config[key] = serverDefault[key] + Note right of settingsStore: Non-overridden params adopt server default + else key 
in userOverrides + Note right of settingsStore: Keep user value, skip server default + end + end - settingsStore->>ParamSvc: mergeWithServerDefaults(config, serverDefaults) - Note right of ParamSvc: For each syncable parameter:
- If NOT in userOverrides → use server default
- If in userOverrides → keep user value - ParamSvc-->>settingsStore: mergedConfig + alt serverStore.props has webuiSettings + settingsStore->>settingsStore: Apply webuiSettings from server + Note right of settingsStore: Server-provided UI settings
(e.g. showRawOutputSwitch) + end - settingsStore->>settingsStore: config = mergedConfig settingsStore->>settingsStore: saveConfig() deactivate settingsStore @@ -67,11 +73,18 @@ sequenceDiagram UI->>settingsStore: updateConfig(key, value) activate settingsStore settingsStore->>settingsStore: config[key] = value - settingsStore->>settingsStore: userOverrides.add(key) - Note right of settingsStore: Mark as user-modified (won't be overwritten by server) + + alt value matches server default for key + settingsStore->>settingsStore: userOverrides.delete(key) + Note right of settingsStore: Matches server default, remove override + else value differs from server default + settingsStore->>settingsStore: userOverrides.add(key) + Note right of settingsStore: Mark as user-modified (won't be overwritten) + end + settingsStore->>settingsStore: saveConfig() - settingsStore->>LS: set("llama-config", config) - settingsStore->>LS: set("llama-userOverrides", [...userOverrides]) + settingsStore->>LS: set(CONFIG_LOCALSTORAGE_KEY, config) + settingsStore->>LS: set(USER_OVERRIDES_LOCALSTORAGE_KEY, [...userOverrides]) deactivate settingsStore UI->>settingsStore: updateMultipleConfig({key1: val1, key2: val2}) @@ -88,10 +101,9 @@ sequenceDiagram UI->>settingsStore: resetConfig() activate settingsStore - settingsStore->>settingsStore: config = SETTING_CONFIG_DEFAULT + settingsStore->>settingsStore: config = {...SETTING_CONFIG_DEFAULT} settingsStore->>settingsStore: userOverrides.clear() - settingsStore->>settingsStore: syncWithServerDefaults() - Note right of settingsStore: Apply server defaults for syncable params + Note right of settingsStore: All params reset to defaults
Next syncWithServerDefaults will adopt server values settingsStore->>settingsStore: saveConfig() deactivate settingsStore @@ -139,6 +151,6 @@ sequenceDiagram Note over settingsStore: UI-only (not synced): rect rgb(255, 240, 240) - Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningFormat + Note over settingsStore: systemMessage, custom (JSON)
showStatistics, enableContinueGeneration
autoMicOnEmpty, disableAutoScroll
apiKey, pdfAsImage, disableReasoningParsing, showRawOutputSwitch end ``` diff --git a/tools/server/webui/src/app.css b/tools/server/webui/src/app.css index 9705040a4d..3ab21f0cc7 100644 --- a/tools/server/webui/src/app.css +++ b/tools/server/webui/src/app.css @@ -14,11 +14,11 @@ --popover-foreground: oklch(0.145 0 0); --primary: oklch(0.205 0 0); --primary-foreground: oklch(0.985 0 0); - --secondary: oklch(0.97 0 0); + --secondary: oklch(0.95 0 0); --secondary-foreground: oklch(0.205 0 0); --muted: oklch(0.97 0 0); --muted-foreground: oklch(0.556 0 0); - --accent: oklch(0.97 0 0); + --accent: oklch(0.95 0 0); --accent-foreground: oklch(0.205 0 0); --destructive: oklch(0.577 0.245 27.325); --border: oklch(0.875 0 0); @@ -37,7 +37,7 @@ --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); - --code-background: oklch(0.975 0 0); + --code-background: oklch(0.985 0 0); --code-foreground: oklch(0.145 0 0); --layer-popover: 1000000; } @@ -51,7 +51,7 @@ --popover-foreground: oklch(0.985 0 0); --primary: oklch(0.922 0 0); --primary-foreground: oklch(0.205 0 0); - --secondary: oklch(0.269 0 0); + --secondary: oklch(0.29 0 0); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); @@ -116,12 +116,62 @@ --color-sidebar-ring: var(--sidebar-ring); } +:root { + --chat-form-area-height: 8rem; + --chat-form-area-offset: 2rem; + --max-message-height: max(24rem, min(80dvh, calc(100dvh - var(--chat-form-area-height) - 12rem))); +} + +@media (min-width: 640px) { + :root { + --chat-form-area-height: 24rem; + --chat-form-area-offset: 12rem; + } +} + @layer base { * { @apply border-border outline-ring/50; } + body { @apply bg-background text-foreground; + scrollbar-width: thin; + scrollbar-gutter: stable; + } + + /* Global scrollbar styling - visible only on hover */ + * { + scrollbar-width: thin; + scrollbar-color: transparent transparent; + transition: scrollbar-color 0.2s ease; + } + + *:hover { + scrollbar-color: hsl(var(--muted-foreground) / 0.3) transparent; + } + + *::-webkit-scrollbar { + width: 6px; + height: 6px; + } + + *::-webkit-scrollbar-track { + background: transparent; + } + + *::-webkit-scrollbar-thumb { + background: transparent; + border-radius: 3px; + transition: background 0.2s ease; + } + + *:hover::-webkit-scrollbar-thumb { + background: hsl(var(--muted-foreground) / 0.3); + } + + *::-webkit-scrollbar-thumb:hover { + background: hsl(var(--muted-foreground) / 0.5); } } diff --git a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte similarity index 99% rename from tools/server/webui/src/lib/components/app/misc/ActionButton.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte index 411a8b6094..4494ea880b 100644 --- a/tools/server/webui/src/lib/components/app/misc/ActionButton.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIcon.svelte @@ -37,6 +37,7 @@ aria-label={ariaLabel || tooltip} > {@const IconComponent = icon} + diff --git a/tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/CopyToClipboardIcon.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIconCopyToClipboard.svelte diff --git 
a/tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte similarity index 94% rename from tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte rename to tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte index 173685510f..1ae3d21774 100644 --- a/tools/server/webui/src/lib/components/app/misc/RemoveButton.svelte +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconRemove.svelte @@ -16,7 +16,7 @@ variant="ghost" size="sm" class="h-6 w-6 bg-white/20 p-0 hover:bg-white/30 {className}" - onclick={(e) => { + onclick={(e: MouseEvent) => { e.stopPropagation(); onRemove?.(id); }} diff --git a/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte new file mode 100644 index 0000000000..b20e79b5e0 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/ActionIconsCodeBlock.svelte @@ -0,0 +1,46 @@ + + +
+
+ +
+ + {#if showPreview} + + {/if} +
diff --git a/tools/server/webui/src/lib/components/app/actions/index.ts b/tools/server/webui/src/lib/components/app/actions/index.ts new file mode 100644 index 0000000000..43485c7b7e --- /dev/null +++ b/tools/server/webui/src/lib/components/app/actions/index.ts @@ -0,0 +1,19 @@ +/** + * + * ACTIONS + * + * Small interactive components for user actions. + * + */ + +/** Styled icon button for action triggers with tooltip. */ +export { default as ActionIcon } from './ActionIcon.svelte'; + +/** Code block actions component (copy, preview). */ +export { default as ActionIconsCodeBlock } from './ActionIconsCodeBlock.svelte'; + +/** Copy-to-clipboard icon button with click handler. */ +export { default as ActionIconCopyToClipboard } from './ActionIconCopyToClipboard.svelte'; + +/** Remove/delete icon button with X icon. */ +export { default as ActionIconRemove } from './ActionIconRemove.svelte'; diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeChatStatistic.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeChatStatistic.svelte diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeInfo.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeInfo.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeInfo.svelte diff --git a/tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte b/tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte similarity index 100% rename from tools/server/webui/src/lib/components/app/misc/BadgeModality.svelte rename to tools/server/webui/src/lib/components/app/badges/BadgeModality.svelte diff --git a/tools/server/webui/src/lib/components/app/badges/index.ts b/tools/server/webui/src/lib/components/app/badges/index.ts new file mode 100644 index 0000000000..860afe3084 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/badges/index.ts @@ -0,0 +1,16 @@ +/** + * + * BADGES & INDICATORS + * + * Small visual indicators for status and metadata. + * + */ + +/** Badge displaying chat statistics (tokens, timing). */ +export { default as BadgeChatStatistic } from './BadgeChatStatistic.svelte'; + +/** Generic info badge with optional tooltip and click handler. */ +export { default as BadgeInfo } from './BadgeInfo.svelte'; + +/** Badge indicating model modality (vision, audio, tools). 
*/ +export { default as BadgeModality } from './BadgeModality.svelte'; diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte index 27ab975cbd..e335f6c546 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatForm.svelte @@ -27,11 +27,13 @@ interface Props { class?: string; disabled?: boolean; + initialMessage?: string; isLoading?: boolean; onFileRemove?: (fileId: string) => void; onFileUpload?: (files: File[]) => void; onSend?: (message: string, files?: ChatUploadedFile[]) => Promise; onStop?: () => void; + onSystemPromptAdd?: (draft: { message: string; files: ChatUploadedFile[] }) => void; showHelperText?: boolean; uploadedFiles?: ChatUploadedFile[]; } @@ -39,11 +41,13 @@ let { class: className, disabled = false, + initialMessage = '', isLoading = false, onFileRemove, onFileUpload, onSend, onStop, + onSystemPromptAdd, showHelperText = true, uploadedFiles = $bindable([]) }: Props = $props(); @@ -53,15 +57,28 @@ let currentConfig = $derived(config()); let fileInputRef: ChatFormFileInputInvisible | undefined = $state(undefined); let isRecording = $state(false); - let message = $state(''); + let message = $derived(initialMessage); let pasteLongTextToFileLength = $derived.by(() => { const n = Number(currentConfig.pasteLongTextToFileLen); return Number.isNaN(n) ? Number(SETTING_CONFIG_DEFAULT.pasteLongTextToFileLen) : n; }); - let previousIsLoading = $state(isLoading); + let previousIsLoading = $derived(isLoading); + let previousInitialMessage = $derived(initialMessage); let recordingSupported = $state(false); let textareaRef: ChatFormTextarea | undefined = $state(undefined); + // Sync message when initialMessage prop changes (e.g., after draft restoration) + $effect(() => { + if (initialMessage !== previousInitialMessage) { + message = initialMessage; + previousInitialMessage = initialMessage; + } + }); + + function handleSystemPromptClick() { + onSystemPromptAdd?.({ message, files: uploadedFiles }); + } + // Check if model is selected (in ROUTER mode) let conversationModel = $derived( chatStore.getConversationModel(activeMessages() as DatabaseMessage[]) @@ -272,7 +289,7 @@
0 || uploadedFiles.length > 0} hasText={message.trim().length > 0} @@ -308,6 +327,7 @@ onFileUpload={handleFileUpload} onMicClick={handleMicClick} onStop={handleStop} + onSystemPromptClick={handleSystemPromptClick} />
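Note on the `ChatForm` changes above: the new `initialMessage` prop and `onSystemPromptAdd` callback depend on a pending-draft mechanism in `chatStore` (`savePendingDraft` / `consumePendingDraft`, exercised by `ChatScreen` later in this patch). The store side is not included in this diff; the sketch below only illustrates the expected one-shot semantics, and the internal field names plus the `ChatUploadedFile` import path are assumptions.

```ts
// Hypothetical sketch of the pending-draft buffer assumed by ChatForm/ChatScreen.
// The real chatStore implementation is not part of this diff and may differ.
import type { ChatUploadedFile } from '$lib/types/chat';

interface PendingDraft {
	message: string;
	files: ChatUploadedFile[];
}

let pendingDraft: PendingDraft | null = null;

// Saved right before the view switches away (e.g. when a system prompt is added),
// so the user's unsent text and attachments survive the transition.
export function savePendingDraft(message: string, files: ChatUploadedFile[]): void {
	pendingDraft = { message, files: [...files] };
}

// Consumed once in ChatScreen's onMount: returns the draft and clears it so the
// restored text is applied to `initialMessage` exactly one time.
export function consumePendingDraft(): PendingDraft | null {
	const draft = pendingDraft;
	pendingDraft = null;
	return draft;
}
```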
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte new file mode 100644 index 0000000000..f8c1b23b06 --- /dev/null +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionAttachmentsDropdown.svelte @@ -0,0 +1,189 @@ + + +
+ + + + + + + + +

{triggerTooltipText}

+
+
+
+ + + {#each actions as item (item.id)} + {@const hasDisabledTooltip = !!item.disabled && !!item.disabledReason} + {@const hasEnabledTooltip = !item.disabled && !!item.tooltip} + + {#if hasDisabledTooltip} + + + + {#if item.id === 'images'} + + {:else if item.id === 'audio'} + + {:else if item.id === 'text'} + + {:else if item.id === 'pdf'} + + {:else} + + {/if} + + {item.label} + + + + +

{item.disabledReason}

+
+
+ {:else if hasEnabledTooltip} + + + handleActionClick(item.id)}> + {#if item.id === 'images'} + + {:else if item.id === 'audio'} + + {:else if item.id === 'text'} + + {:else if item.id === 'pdf'} + + {:else} + + {/if} + + {item.label} + + + + +

{item.tooltip}

+
+
+ {:else} + handleActionClick(item.id)}> + {#if item.id === 'images'} + + {:else if item.id === 'audio'} + + {:else if item.id === 'text'} + + {:else if item.id === 'pdf'} + + {:else} + + {/if} + + {item.label} + + {/if} + {/each} +
+
+
diff --git a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte index dd37268096..3545b4aebf 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatForm/ChatFormActions/ChatFormActionFileAttachments.svelte @@ -1,5 +1,6 @@
- +
+ +
- +
+ +
{#if isLoading} {:else if shouldShowRecordButton} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte index 220276fc9e..25895c83b7 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessage.svelte @@ -1,6 +1,16 @@ -
+
@@ -81,6 +88,16 @@
+ + {#if showRawOutputSwitch} +
+ Show raw output + onRawOutputToggle?.(checked)} + /> +
+ {/if}
{ @@ -102,7 +105,7 @@ const { handleModelChange } = useModelChangeValidation({ getRequiredModalities: () => conversationsStore.getModalitiesUpToMessage(message.id), - onSuccess: (modelName) => onRegenerate(modelName) + onSuccess: (modelName: string) => onRegenerate(modelName) }); function handleCopyModel() { @@ -238,7 +241,7 @@
{:else if message.role === 'assistant'} - {#if config().disableReasoningFormat} + {#if showRawOutput}
{messageContent || ''}
{:else} @@ -352,6 +355,9 @@ {onConfirmDelete} {onNavigateToSibling} {onShowDeleteDialogChange} + showRawOutputSwitch={currentConfig.showRawOutputSwitch} + rawOutputEnabled={showRawOutput} + onRawOutputToggle={(enabled) => (showRawOutput = enabled)} /> {/if} diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte index f812ea2fd9..c216ea690b 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageEditForm.svelte @@ -133,7 +133,7 @@ const { handleModelChange } = useModelChangeValidation({ getRequiredModalities, - onValidationFailure: async (previousModelId) => { + onValidationFailure: async (previousModelId: string | null) => { if (previousModelId) { await modelsStore.selectModelById(previousModelId); } diff --git a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte index 24fe5926ba..b53e82aaf9 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatMessages/ChatMessageStatistics.svelte @@ -3,6 +3,7 @@ import { BadgeChatStatistic } from '$lib/components/app'; import * as Tooltip from '$lib/components/ui/tooltip'; import { ChatMessageStatsView } from '$lib/enums'; + import { formatPerformanceTime } from '$lib/utils/formatters'; interface Props { predictedTokens?: number; @@ -27,7 +28,7 @@ initialView = ChatMessageStatsView.GENERATION }: Props = $props(); - let activeView: ChatMessageStatsView = $state(initialView); + let activeView: ChatMessageStatsView = $derived(initialView); let hasAutoSwitchedToGeneration = $state(false); // In live mode: auto-switch to GENERATION tab when prompt processing completes @@ -57,8 +58,8 @@ ); let tokensPerSecond = $derived(hasGenerationStats ? (predictedTokens! / predictedMs!) * 1000 : 0); - let timeInSeconds = $derived( - predictedMs !== undefined ? (predictedMs / 1000).toFixed(2) : '0.00' + let formattedTime = $derived( + predictedMs !== undefined ? formatPerformanceTime(predictedMs) : '0s' ); let promptTokensPerSecond = $derived( @@ -67,15 +68,15 @@ : undefined ); - let promptTimeInSeconds = $derived( - promptMs !== undefined ? (promptMs / 1000).toFixed(2) : undefined + let formattedPromptTime = $derived( + promptMs !== undefined ? 
formatPerformanceTime(promptMs) : undefined ); let hasPromptStats = $derived( promptTokens !== undefined && promptMs !== undefined && promptTokensPerSecond !== undefined && - promptTimeInSeconds !== undefined + formattedPromptTime !== undefined ); // In live mode, generation tab is disabled until we have generation stats @@ -142,7 +143,7 @@ - Send + Save diff --git a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte index 27439551a1..3d432e26bc 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatScreen/ChatScreen.svelte @@ -21,6 +21,7 @@ chatStore, errorDialog, isLoading, + isChatStreaming, isEditing, getAddFilesHandler } from '$lib/stores/chat.svelte'; @@ -34,6 +35,7 @@ import { modelsStore, modelOptions, selectedModelId } from '$lib/stores/models.svelte'; import { isFileTypeSupported, filterFilesByModalities } from '$lib/utils'; import { parseFilesToMessageExtras, processFilesToChatUploaded } from '$lib/utils/browser-only'; + import { ErrorDialogType } from '$lib/enums'; import { onMount } from 'svelte'; import { fade, fly, slide } from 'svelte/transition'; import { Trash2, AlertTriangle, RefreshCw } from '@lucide/svelte'; @@ -71,6 +73,8 @@ let emptyFileNames = $state([]); + let initialMessage = $state(''); + let isEmpty = $derived( showCenteredEmpty && !activeConversation() && activeMessages().length === 0 && !isLoading() ); @@ -79,7 +83,7 @@ let isServerLoading = $derived(serverLoading()); let hasPropsError = $derived(!!serverError()); - let isCurrentConversationLoading = $derived(isLoading()); + let isCurrentConversationLoading = $derived(isLoading() || isChatStreaming()); let isRouter = $derived(isRouterMode()); @@ -221,6 +225,14 @@ } } + async function handleSystemPromptAdd(draft: { message: string; files: ChatUploadedFile[] }) { + if (draft.message || draft.files.length > 0) { + chatStore.savePendingDraft(draft.message, draft.files); + } + + await chatStore.addSystemPrompt(); + } + function handleScroll() { if (disableAutoScroll || !chatScrollContainer) return; @@ -343,6 +355,12 @@ if (!disableAutoScroll) { setTimeout(() => scrollChatToBottom('instant'), INITIAL_SCROLL_DELAY); } + + const pendingDraft = chatStore.consumePendingDraft(); + if (pendingDraft) { + initialMessage = pendingDraft.message; + uploadedFiles = pendingDraft.files; + } }); $effect(() => { @@ -428,11 +446,13 @@
chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={false} bind:uploadedFiles /> @@ -486,11 +506,13 @@
chatStore.stopGeneration()} + onSystemPromptAdd={handleSystemPromptAdd} showHelperText={true} bind:uploadedFiles /> @@ -595,7 +617,7 @@ contextInfo={activeErrorDialog?.contextInfo} onOpenChange={handleErrorDialogOpenChange} open={Boolean(activeErrorDialog)} - type={activeErrorDialog?.type ?? 'server'} + type={(activeErrorDialog?.type as ErrorDialogType) ?? ErrorDialogType.SERVER} /> diff --git a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte similarity index 90% rename from tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte rename to tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte index bc42f9dd1e..625fdc7b1b 100644 --- a/tools/server/webui/src/lib/components/app/misc/SyntaxHighlightedCode.svelte +++ b/tools/server/webui/src/lib/components/app/content/SyntaxHighlightedCode.svelte @@ -71,13 +71,11 @@
-
{@html highlightedHtml}
+
{@html highlightedHtml}
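A note on the `ChatMessageStatistics` hunk earlier in this patch: it replaces the inline `(predictedMs / 1000).toFixed(2)` strings with `formatPerformanceTime` from `$lib/utils/formatters`. The formatter itself is not part of this diff; given the `'0s'` fallback, a plausible sketch of its behavior could look like the following (everything beyond the exported name is an assumption).

```ts
// Hypothetical sketch of the formatter consumed by ChatMessageStatistics.
// The actual implementation lives in $lib/utils/formatters and may differ.
export function formatPerformanceTime(ms: number): string {
	if (!Number.isFinite(ms) || ms <= 0) return '0s';

	const totalSeconds = ms / 1000;

	// Sub-minute durations: keep two decimals, matching the old "12.34" output but with a unit.
	if (totalSeconds < 60) {
		return `${totalSeconds.toFixed(2)}s`;
	}

	// Longer prompt-processing runs read better as minutes plus whole seconds.
	const minutes = Math.floor(totalSeconds / 60);
	const seconds = Math.floor(totalSeconds % 60);
	return `${minutes}m ${seconds}s`;
}

// Example: formatPerformanceTime(1234) === '1.23s', formatPerformanceTime(75500) === '1m 15s'.
```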