diff --git a/.devops/vulkan.Dockerfile b/.devops/vulkan.Dockerfile index b37b4f277d..89831ed5c2 100644 --- a/.devops/vulkan.Dockerfile +++ b/.devops/vulkan.Dockerfile @@ -33,6 +33,7 @@ FROM ubuntu:$UBUNTU_VERSION AS base RUN apt-get update \ && apt-get install -y libgomp1 curl libvulkan1 mesa-vulkan-drivers \ + libglvnd0 libgl1 libglx0 libegl1 libgles2 \ && apt autoremove -y \ && apt clean -y \ && rm -rf /tmp/* /var/tmp/* \ diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index de3ad06065..85601b3712 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -1098,6 +1098,7 @@ jobs: save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }} - name: Build with CMake + # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project run: | cmake -S . -B build -G Ninja \ -DLLAMA_CURL=OFF \ @@ -1107,7 +1108,8 @@ jobs: -DCMAKE_CUDA_ARCHITECTURES=89-real \ -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \ -DGGML_NATIVE=OFF \ - -DGGML_CUDA=ON + -DGGML_CUDA=ON \ + -DGGML_CUDA_CUB_3DOT2=ON cmake --build build windows-2022-cmake-cuda: @@ -1143,6 +1145,7 @@ jobs: - name: Build id: cmake_build shell: cmd + # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project run: | call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 cmake -S . -B build -G "Ninja Multi-Config" ^ @@ -1153,7 +1156,8 @@ jobs: -DGGML_BACKEND_DL=ON ^ -DGGML_CPU_ALL_VARIANTS=ON ^ -DGGML_CUDA=ON ^ - -DGGML_RPC=ON + -DGGML_RPC=ON ^ + -DGGML_CUDA_CUB_3DOT2=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% -t ggml cmake --build build --config Release @@ -1414,7 +1418,6 @@ jobs: echo "FIXME: test on devices" openEuler-latest-cmake-cann: - if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }} defaults: run: shell: bash -el {0} @@ -1750,7 +1753,7 @@ jobs: sudo apt-get update # Install necessary packages - sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs # Set gcc-14 and g++-14 as the default compilers sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 @@ -1762,6 +1765,8 @@ jobs: rustup install stable rustup default stable + git lfs install + - name: Clone id: checkout uses: actions/checkout@v4 @@ -1847,7 +1852,7 @@ jobs: sudo apt-get update # Install necessary packages - sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs # Set gcc-14 and g++-14 as the default compilers sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 @@ -1859,6 +1864,8 @@ jobs: rustup install stable rustup default stable + git lfs install + - name: GCC version check run: | gcc --version @@ -1939,7 +1946,7 @@ jobs: sudo apt-get update # Install necessary packages - sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs # Set gcc-14 and g++-14 as the default compilers sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 @@ -1951,6 +1958,8 @@ jobs: rustup install stable rustup default stable + git lfs install + - name: GCC version check run: | gcc --version @@ -2011,7 +2020,7 @@ jobs: sudo apt-get update # Install necessary packages - sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache + sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs # Set gcc-14 and g++-14 as the default compilers sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100 @@ -2023,6 +2032,8 @@ jobs: rustup install stable rustup default stable + git lfs install + - name: GCC version check run: | gcc --version diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 4cc2f4665c..bf5ebb7559 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -420,6 +420,7 @@ jobs: - name: Build id: cmake_build shell: cmd + # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project run: | call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64 cmake -S . -B build -G "Ninja Multi-Config" ^ @@ -427,7 +428,8 @@ jobs: -DGGML_NATIVE=OFF ^ -DGGML_CPU=OFF ^ -DGGML_CUDA=ON ^ - -DLLAMA_CURL=OFF + -DLLAMA_CURL=OFF ^ + -DGGML_CUDA_CUB_3DOT2=ON set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1 cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda diff --git a/.github/workflows/server.yml b/.github/workflows/server.yml index f9e2a79af7..5694feb2c9 100644 --- a/.github/workflows/server.yml +++ b/.github/workflows/server.yml @@ -41,6 +41,10 @@ jobs: include: - build_type: Release sanitizer: "" + extra_args: "" + - build_type: Release + sanitizer: "" + extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1" fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken steps: @@ -65,6 +69,12 @@ jobs: fetch-depth: 0 ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }} + - name: Build + id: cmake_build + run: | + cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON + cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server + - name: Python setup id: setup_python uses: actions/setup-python@v5 @@ -76,6 +86,14 @@ jobs: run: | pip install -r tools/server/tests/requirements.txt + - name: Tests + id: server_integration_tests + if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }} + run: | + cd tools/server/tests + export ${{ matrix.extra_args }} + pytest -v -x -m "not slow" + server-windows: runs-on: windows-2022 diff --git a/.gitignore b/.gitignore index 05eb578a82..bb122d6924 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,7 @@ poetry.toml # Local scripts /run-vim.sh /run-chat.sh +/run-spec.sh /.ccache/ # IDE diff --git a/README.md b/README.md index ed956bb02e..e59612f7ae 100644 --- a/README.md +++ b/README.md @@ -482,21 +482,6 @@ To learn more about model quantization, [read this documentation](tools/quantize -## [`llama-run`](tools/run) - -#### A comprehensive example for running `llama.cpp` models. Useful for inferencing. Used with RamaLama [^3]. - --
- Run a model with a specific prompt (by default it's pulled from Ollama registry) - - ```bash - llama-run granite-code - ``` - -
- -[^3]: [RamaLama](https://github.com/containers/ramalama) - ## [`llama-simple`](examples/simple) #### A minimal example for implementing apps with `llama.cpp`. Useful for developers. @@ -600,7 +585,6 @@ $ echo "source ~/.llama-completion.bash" >> ~/.bashrc - [stb-image](https://github.com/nothings/stb) - Single-header image format decoder, used by multimodal subsystem - Public domain - [nlohmann/json](https://github.com/nlohmann/json) - Single-header JSON library, used by various tools/examples - MIT License - [minja](https://github.com/google/minja) - Minimal Jinja parser in C++, used by various tools/examples - MIT License -- [linenoise.cpp](./tools/run/linenoise.cpp/linenoise.cpp) - C++ library that provides readline-like line editing capabilities, used by `llama-run` - BSD 2-Clause License - [curl](https://curl.se/) - Client-side URL transfer library, used by various tools/examples - [CURL License](https://curl.se/docs/copyright.html) - [miniaudio.h](https://github.com/mackron/miniaudio) - Single-header audio format decoder, used by multimodal subsystem - Public domain - [subprocess.h](https://github.com/sheredom/subprocess.h) - Single-header process launching solution for C and C++ - Public domain diff --git a/ci/run.sh b/ci/run.sh index 0a4a0e41eb..5c2d325a56 100755 --- a/ci/run.sh +++ b/ci/run.sh @@ -52,7 +52,8 @@ if [ ! -z ${GG_BUILD_METAL} ]; then fi if [ ! -z ${GG_BUILD_CUDA} ]; then - CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON" + # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project + CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DGGML_CUDA_CUB_3DOT2=ON" if command -v nvidia-smi >/dev/null 2>&1; then CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.') diff --git a/common/arg.cpp b/common/arg.cpp index 62d31393c4..e7966d9d5c 100644 --- a/common/arg.cpp +++ b/common/arg.cpp @@ -679,7 +679,6 @@ static void common_params_print_completion(common_params_context & ctx_arg) { "llama-quantize", "llama-qwen2vl-cli", "llama-retrieval", - "llama-run", "llama-save-load-state", "llama-server", "llama-simple", @@ -854,6 +853,54 @@ bool common_arg_utils::is_autoy(const std::string & value) { return value == "auto" || value == "-1"; } +// Simple CSV parser that handles quoted fields and escaped quotes +// example: +// input: value1,"value, with, commas","value with ""escaped"" quotes",value4 +// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4] +static std::vector parse_csv_row(const std::string& input) { + std::vector fields; + std::string field; + bool in_quotes = false; + + for (size_t i = 0; i < input.length(); ++i) { + char ch = input[i]; + + if (ch == '"') { + if (!in_quotes) { + // start of quoted field (only valid if at beginning of field) + if (!field.empty()) { + // quote appeared in middle of unquoted field, treat as literal + field += '"'; + } else { + in_quotes = true; // start + } + } else { + if (i + 1 < input.length() && input[i + 1] == '"') { + // escaped quote: "" + field += '"'; + ++i; // skip the next quote + } else { + in_quotes = false; // end + } + } + } else if (ch == ',') { + if (in_quotes) { + field += ','; + } else { + fields.push_back(std::move(field)); + field.clear(); + } + } else { + field += ch; + } + } + + // Add the last field + fields.push_back(std::move(field)); + + return fields; +} + common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { // per-example default params // we define here to make sure it's included in llama-gen-docs @@ -1250,7 +1297,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--in-file"}, "FNAME", "an input file (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -1397,7 +1444,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, bool value) { params.warmup = value; } - ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY})); + ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--spm-infill"}, string_format( @@ -1695,6 +1742,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex params.sampling.grammar = json_schema_to_grammar(json::parse(schema)); } ).set_sparam()); + add_opt(common_arg( + {"-bs", "--backend-sampling"}, + "enable backend sampling (experimental) (default: disabled)", + [](common_params & params) { + params.sampling.backend_sampling = true; + } + ).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING")); add_opt(common_arg( {"--pooling"}, "{none,mean,cls,last,rank}", "pooling type for embeddings, use model default if unspecified", @@ -1706,7 +1760,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; } else { throw std::invalid_argument("invalid value"); } } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING")); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING")); add_opt(common_arg( {"--attention"}, "{causal,non-causal}", "attention type for embeddings, use model default if unspecified", @@ -1995,7 +2049,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--image", "--audio"}, "FILE", "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.image.emplace_back(item); } } @@ -2252,37 +2306,12 @@ common_params_context common_params_parser_init(common_params & params, llama_ex )); add_opt(common_arg( {"--override-kv"}, "KEY=TYPE:VALUE,...", - "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n" + "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n" "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false", [](common_params & params, const std::string & value) { - std::vector kv_overrides; - - std::string current; - bool escaping = false; - - for (const char c : value) { - if (escaping) { - current.push_back(c); - escaping = false; - } else if (c == '\\') { - escaping = true; - } else if (c == ',') { - kv_overrides.push_back(current); - current.clear(); - } else { - current.push_back(c); - } - } - - if (escaping) { - current.push_back('\\'); - } - - kv_overrides.push_back(current); - - for (const auto & kv_override : kv_overrides) { - if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) { - throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str())); + for (const auto & item : parse_csv_row(value)) { + if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) { + throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str())); } } } @@ -2299,7 +2328,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--lora"}, "FNAME", "path to LoRA adapter (use comma-separated values to load multiple adapters)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.lora_adapters.push_back({ item, 1.0, "", "", nullptr }); } } @@ -2310,7 +2339,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n" "note: use comma-separated values", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("lora-scaled format: FNAME:SCALE"); @@ -2324,7 +2353,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--control-vector"}, "FNAME", "add a control vector\nnote: use comma-separated values to add multiple control vectors", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { params.control_vectors.push_back({ 1.0f, item, }); } } @@ -2334,7 +2363,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex "add a control vector with user defined scaling SCALE\n" "note: use comma-separated values (format: FNAME:SCALE,...)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { auto parts = string_split(item, ':'); if (parts.size() != 2) { throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE"); @@ -2432,7 +2461,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex {"--context-file"}, "FNAME", "file to load context from (use comma-separated values to specify multiple files)", [](common_params & params, const std::string & value) { - for (const auto & item : string_split(value, ',')) { + for (const auto & item : parse_csv_row(value)) { std::ifstream file(item, std::ios::binary); if (!file) { throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str())); @@ -2579,7 +2608,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params, int value) { params.embd_normalize = value; } - ).set_examples({LLAMA_EXAMPLE_EMBEDDING})); + ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG})); add_opt(common_arg( {"--embd-output-format"}, "FORMAT", "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)", @@ -2657,7 +2686,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex [](common_params & params) { params.embedding = true; } - ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS")); + ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS")); add_opt(common_arg( {"--rerank", "--reranking"}, string_format("enable reranking endpoint on server (default: %s)", "disabled"), @@ -2668,9 +2697,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING")); add_opt(common_arg( {"--api-key"}, "KEY", - "API key to use for authentication (default: none)", + "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)", [](common_params & params, const std::string & value) { - params.api_keys.push_back(value); + for (const auto & key : parse_csv_row(value)) { + if (!key.empty()) { + params.api_keys.push_back(key); + } + } } ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY")); add_opt(common_arg( @@ -2684,7 +2717,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex std::string key; while (std::getline(key_file, key)) { if (!key.empty()) { - params.api_keys.push_back(key); + params.api_keys.push_back(key); } } key_file.close(); @@ -2706,7 +2739,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE")); add_opt(common_arg( {"--chat-template-kwargs"}, "STRING", - string_format("sets additional params for the json template parser"), + "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'", [](common_params & params, const std::string & value) { auto parsed = json::parse(value); for (const auto & item : parsed.items()) { @@ -3344,6 +3377,27 @@ common_params_context common_params_parser_init(common_params & params, llama_ex } } ).set_examples({ LLAMA_EXAMPLE_FINETUNE })); + add_opt(common_arg( + {"--save-logits"}, + string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"), + [](common_params & params) { + params.save_logits = true; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--logits-output-dir"}, "PATH", + string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()), + [](common_params & params, const std::string & value) { + params.logits_output_dir = value; + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); + add_opt(common_arg( + {"--tensor-filter"}, "REGEX", + "filter tensor names for debug output (regex pattern, can be specified multiple times)", + [](common_params & params, const std::string & value) { + params.tensor_filter.push_back(value); + } + ).set_examples({LLAMA_EXAMPLE_DEBUG})); // presets add_opt(common_arg( diff --git a/common/chat-parser.cpp b/common/chat-parser.cpp index d740dac065..23e23ca8c7 100644 --- a/common/chat-parser.cpp +++ b/common/chat-parser.cpp @@ -1395,6 +1395,14 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { builder.consume_reasoning_with_xml_tool_calls(form, "", ""); } +static void common_chat_parse_solar_open(common_chat_msg_parser & builder) { + builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>"); + + // TODO: Tool calling + + builder.add_content(builder.consume_rest()); +} + static void common_chat_parse_content_only(common_chat_msg_parser & builder) { builder.try_parse_reasoning("", ""); builder.add_content(builder.consume_rest()); @@ -1479,6 +1487,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) { case COMMON_CHAT_FORMAT_XIAOMI_MIMO: common_chat_parse_xiaomi_mimo(builder); break; + case COMMON_CHAT_FORMAT_SOLAR_OPEN: + common_chat_parse_solar_open(builder); + break; default: throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); } diff --git a/common/chat.cpp b/common/chat.cpp index 7e940695bd..22e527bab8 100644 --- a/common/chat.cpp +++ b/common/chat.cpp @@ -669,6 +669,7 @@ const char * common_chat_format_name(common_chat_format format) { case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder"; case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5"; case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo"; + case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open"; case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple"; case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native"; case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed"; @@ -2064,7 +2065,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp // Trigger on tool calls that appear in the commentary channel data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, - "<\\|channel\\|>(commentary|analysis) to" + "<\\|channel\\|>(?:commentary|analysis) to" }); // Trigger tool calls that appear in the role section, either at the @@ -2397,17 +2398,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call)); // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives) data.grammar_triggers.push_back({ - COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL, + COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN, // If thinking_forced_open, then we capture the tag in the grammar, // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) - std::string(data.thinking_forced_open ? "[\\s\\S]*?(\\s*)" : "(?:[\\s\\S]*?\\s*)?") + ( + std::string(data.thinking_forced_open ? "(\\s*)" : "") + ( "\\s*(" "(?:" "||||)?" "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\"" ")" - ")[\\s\\S]*" + ")" ), }); data.preserved_tokens = { @@ -2517,6 +2518,27 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp return data; } +static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) { + common_chat_params data; + + // TODO: Reasoning effort + json additional_context = {}; + + data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context); + data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN; + + data.preserved_tokens = { + "<|think|>", + "<|content|>", + "<|begin|>", + "<|end|>", + }; + + // TODO: Tool calling + + return data; +} + static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) { common_chat_params data; data.prompt = apply(tmpl, inputs); @@ -2780,6 +2802,13 @@ static common_chat_params common_chat_templates_apply_jinja( return common_chat_params_init_magistral(tmpl, params); } + // Solar Open + if (src.find("<|tool_response:begin|>") != std::string::npos && + src.find("<|tool_response:name|>") != std::string::npos && + src.find("<|tool_response:result|>") != std::string::npos) { + return common_chat_params_init_solar_open(tmpl, params); + } + // Plain handler (no tools) if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) { return common_chat_params_init_without_tools(tmpl, params); diff --git a/common/chat.h b/common/chat.h index 6085510a40..8bd4a325ff 100644 --- a/common/chat.h +++ b/common/chat.h @@ -124,6 +124,7 @@ enum common_chat_format { COMMON_CHAT_FORMAT_QWEN3_CODER_XML, COMMON_CHAT_FORMAT_APRIEL_1_5, COMMON_CHAT_FORMAT_XIAOMI_MIMO, + COMMON_CHAT_FORMAT_SOLAR_OPEN, // These are intended to be parsed by the PEG parser COMMON_CHAT_FORMAT_PEG_SIMPLE, diff --git a/common/common.cpp b/common/common.cpp index 79c4756125..41b2b6833e 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -1086,6 +1086,7 @@ struct common_init_result::impl { std::vector lora; std::vector samplers; + std::vector samplers_seq_config; }; common_init_result::common_init_result(common_params & params) : @@ -1162,10 +1163,19 @@ common_init_result::common_init_result(common_params & params) : // params.sampling.dry_penalty_last_n = llama_n_ctx(lctx); //} + // init the backend samplers as part of the context creation pimpl->samplers.resize(cparams.n_seq_max); + pimpl->samplers_seq_config.resize(cparams.n_seq_max); for (int i = 0; i < (int) cparams.n_seq_max; ++i) { pimpl->samplers[i].reset(common_sampler_init(model, params.sampling)); + pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) }; + } + + // TODO: temporarily gated behind a flag + if (params.sampling.backend_sampling) { + cparams.samplers = pimpl->samplers_seq_config.data(); + cparams.n_samplers = pimpl->samplers_seq_config.size(); } llama_context * lctx = llama_init_from_model(model, cparams); @@ -1189,6 +1199,12 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) { return pimpl->samplers[seq_id].get(); } +void common_init_result::reset_samplers() { + for (int i = 0; i < (int) pimpl->samplers.size(); ++i) { + llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get())); + } +} + std::vector & common_init_result::lora() { return pimpl->lora; } @@ -1304,6 +1320,9 @@ common_init_result_ptr common_init_from_params(common_params & params) { llama_synchronize(lctx); llama_perf_context_reset(lctx); llama_set_warmup(lctx, false); + + // reset samplers to reset RNG state after warmup to the seeded state + res->reset_samplers(); } return res; diff --git a/common/common.h b/common/common.h index 55749dd8c7..33b7849a8a 100644 --- a/common/common.h +++ b/common/common.h @@ -80,6 +80,7 @@ int32_t cpu_get_num_math(); // enum llama_example { + LLAMA_EXAMPLE_DEBUG, LLAMA_EXAMPLE_COMMON, LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_COMPLETION, @@ -216,6 +217,8 @@ struct common_params_sampling { std::vector logit_bias; // logit biases to apply std::vector logit_bias_eog; // pre-calculated logit biases for EOG tokens + bool backend_sampling = false; + bool has_logit_bias() const { return !logit_bias.empty(); } @@ -370,6 +373,11 @@ struct common_params { std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT std::string logits_file = ""; // file for saving *all* logits // NOLINT + // llama-debug specific options + std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT + bool save_logits = false; // whether to save logits to files // NOLINT + std::vector tensor_filter; // filter tensor names for debug output (regex) // NOLINT + std::vector in_files; // all input files std::vector antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts) std::vector kv_overrides; @@ -690,7 +698,9 @@ struct common_init_result { llama_model * model(); llama_context * context(); + common_sampler * sampler(llama_seq_id seq_id); + void reset_samplers(); std::vector & lora(); diff --git a/common/llguidance.cpp b/common/llguidance.cpp index adce620e4d..d58f147a76 100644 --- a/common/llguidance.cpp +++ b/common/llguidance.cpp @@ -106,12 +106,16 @@ static void llama_sampler_llg_free(llama_sampler * smpl) { } static llama_sampler_i llama_sampler_llg_i = { - /* .name = */ llama_sampler_llg_name, - /* .accept = */ llama_sampler_llg_accept_impl, - /* .apply = */ llama_sampler_llg_apply, - /* .reset = */ llama_sampler_llg_reset, - /* .clone = */ llama_sampler_llg_clone, - /* .free = */ llama_sampler_llg_free, + /* .name = */ llama_sampler_llg_name, + /* .accept = */ llama_sampler_llg_accept_impl, + /* .apply = */ llama_sampler_llg_apply, + /* .reset = */ llama_sampler_llg_reset, + /* .clone = */ llama_sampler_llg_clone, + /* .free = */ llama_sampler_llg_free, + /* .backend_init = */ NULL, + /* .backend_accept = */ NULL, + /* .backend_apply = */ NULL, + /* .backend_set_input = */ NULL, }; static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len, diff --git a/common/regex-partial.cpp b/common/regex-partial.cpp index 4bff6b6633..e667a209e9 100644 --- a/common/regex-partial.cpp +++ b/common/regex-partial.cpp @@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b return res; } std::match_results srmatch; - if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) { + if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) { auto group = srmatch[1].str(); if (group.length() != 0) { auto it = srmatch[1].second.base(); @@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b to see if a string ends with a partial regex match, but but it's not in std::regex yet. Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input. - - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).* - - /a|b/ -> (a|b).* + - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a) + - /a|b/ -> ^(a|b) - /a*?/ -> error, could match "" - - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager) - - /.*?ab/ -> ((?:b)?a).* (merge .*) - - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches) - - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).* - - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).* - - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).* + - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager) + - /.*?ab/ -> ^((?:b)?a) (omit .*) + - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches) + - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a) + - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a) + - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a) - The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern - (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored) + The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern. + All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored. */ std::string regex_to_reversed_partial_regex(const std::string & pattern) { auto it = pattern.begin(); @@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { } } - // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).* + // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a) // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group // We'll do the outermost capturing group and final .* in the enclosing function. std::vector res_alts; @@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) { throw std::runtime_error("Unmatched '(' in pattern"); } - return "(" + res + ")[\\s\\S]*"; + return "^(" + res + ")"; } diff --git a/common/sampling.cpp b/common/sampling.cpp index c66f935c65..8a931d51fc 100644 --- a/common/sampling.cpp +++ b/common/sampling.cpp @@ -120,17 +120,34 @@ struct common_sampler { } void set_logits(struct llama_context * ctx, int idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); const int n_vocab = llama_vocab_n_tokens(vocab); - cur.resize(n_vocab); - - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (uint32_t i = 0; i < sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } cur_p = { cur.data(), cur.size(), -1, false }; @@ -159,7 +176,7 @@ std::string common_params_sampling::print() const { return std::string(result); } -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) { +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) { const llama_vocab * vocab = llama_model_get_vocab(model); llama_sampler_chain_params lparams = llama_sampler_chain_default_params(); @@ -179,24 +196,30 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co #endif // LLAMA_USE_LLGUIDANCE } else { std::vector trigger_patterns; - std::vector patterns_anywhere; std::vector trigger_tokens; for (const auto & trigger : params.grammar_triggers) { switch (trigger.type) { case COMMON_GRAMMAR_TRIGGER_TYPE_WORD: { const auto & word = trigger.value; - patterns_anywhere.push_back(regex_escape(word)); + trigger_patterns.push_back(regex_escape(word)); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN: { - patterns_anywhere.push_back(trigger.value); + trigger_patterns.push_back(trigger.value); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL: { - trigger_patterns.push_back(trigger.value); + const auto & pattern = trigger.value; + std::string anchored = "^$"; + if (!pattern.empty()) { + anchored = (pattern.front() != '^' ? "^" : "") + + pattern + + (pattern.back() != '$' ? "$" : ""); + } + trigger_patterns.push_back(anchored); break; } case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN: @@ -210,10 +233,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co } } - if (!patterns_anywhere.empty()) { - trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*"); - } - std::vector trigger_patterns_c; trigger_patterns_c.reserve(trigger_patterns.size()); for (const auto & regex : trigger_patterns) { @@ -296,6 +315,12 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, co llama_sampler_chain_add(chain, smpl); } + if (grmr && params.backend_sampling) { + LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__); + + params.backend_sampling = false; + } + auto * result = new common_sampler { /* .params = */ params, /* .grmr = */ grmr, @@ -405,6 +430,25 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co auto & chain = gsmpl->chain; auto & cur_p = gsmpl->cur_p; // initialized by set_logits + // Check if a backend sampler has already sampled a token in which case we + // return that token id directly. + { + id = llama_get_sampled_token_ith(ctx, idx); + + if (id != LLAMA_TOKEN_NULL) { + LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id); + + GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported"); + + // TODO: simplify + gsmpl->cur.resize(1); + gsmpl->cur[0] = { id, 0.0f, 1.0f }; + cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true }; + + return id; + } + } + gsmpl->set_logits(ctx, idx); if (grammar_first) { diff --git a/common/sampling.h b/common/sampling.h index c7101032f2..5b57ad6581 100644 --- a/common/sampling.h +++ b/common/sampling.h @@ -36,7 +36,8 @@ struct common_sampler; // llama_sampler API overloads -struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params); +// note: can mutate params in some cases +struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params); void common_sampler_free(struct common_sampler * gsmpl); @@ -48,6 +49,7 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl); // arguments can be nullptr to skip printing void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl); +// get the underlying llama_sampler_chain struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl); // extended sampling implementation: diff --git a/convert_hf_to_gguf.py b/convert_hf_to_gguf.py index edc0ed539d..0a8bac0e2d 100755 --- a/convert_hf_to_gguf.py +++ b/convert_hf_to_gguf.py @@ -771,9 +771,14 @@ class TextModel(ModelBase): self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} + rope_theta = self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True) + local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "swa_rope_theta", "rope_local_base_freq"], optional=True) + # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: - if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None: + if local_rope_theta is not None: + self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta} + if "rope_theta" not in self.rope_parameters and rope_theta is not None: self.rope_parameters["rope_theta"] = rope_theta if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: self.rope_parameters["rope_type"] = rope_type @@ -839,6 +844,7 @@ class TextModel(ModelBase): self.gguf_writer.add_head_count_kv(n_head_kv) logger.info(f"gguf: key-value head count = {n_head_kv}") + # TODO: Handle "sliding_attention" similarly when models start implementing it rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) if (rope_type := rope_params.get("rope_type")) is not None: rope_factor = rope_params.get("factor") @@ -885,6 +891,9 @@ class TextModel(ModelBase): if (rope_theta := rope_params.get("rope_theta")) is not None: self.gguf_writer.add_rope_freq_base(rope_theta) logger.info(f"gguf: rope theta = {rope_theta}") + if (local_rope_theta := self.rope_parameters.get("sliding_attention", {}).get("rope_theta")) is not None: + self.gguf_writer.add_rope_freq_base_swa(local_rope_theta) + logger.info(f"gguf: rope theta swa = {local_rope_theta}") if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None: self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") @@ -1062,6 +1071,9 @@ class TextModel(ModelBase): if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer res = "grok-2" + if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df": + # ref: https://huggingface.co/aari1995/German_Semantic_V3 + res = "jina-v2-de" if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B res = "llama-bpe" @@ -1230,6 +1242,12 @@ class TextModel(ModelBase): if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer res = "kormo" + if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1": + # ref: https://huggingface.co/tencent/Youtu-LLM-2B + res = "youtu" + if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91": + # ref: https://huggingface.co/upstage/Solar-Open-100B + res = "solar-open" if res is None: logger.warning("\n") @@ -2486,6 +2504,7 @@ class StableLMModel(TextModel): "VLlama3ForCausalLM", "LlavaForConditionalGeneration", "VoxtralForConditionalGeneration", + "IQuestCoderForCausalLM", "LlamaModel") class LlamaModel(TextModel): model_arch = gguf.MODEL_ARCH.LLAMA @@ -4994,7 +5013,6 @@ class Plamo3Model(TextModel): if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None: self.gguf_writer.add_sliding_window(sliding_window) self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"]) - self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"]) def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: @@ -5284,13 +5302,14 @@ class BertModel(TextModel): self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) # convert to phantom space vocab - def phantom(tok): - if tok.startswith("[") and tok.endswith("]"): + def phantom(tok, toktype): + if toktype == gguf.TokenType.CONTROL: return tok if tok.startswith("##"): return tok[2:] return "\u2581" + tok - tokens = list(map(phantom, tokens)) + assert len(tokens) == len(toktypes) + tokens = list(map(phantom, tokens, toktypes)) # add vocab to gguf self.gguf_writer.add_tokenizer_model("bert") @@ -6404,6 +6423,17 @@ class ARwkv7Model(Rwkv7Model): self.gguf_writer.add_head_count(0) +@ModelBase.register("MaincoderForCausalLM") +class MaincoderModel(TextModel): + model_arch = gguf.MODEL_ARCH.MAINCODER + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + if (head_dim := self.hparams.get("head_dim")) is not None: + self.gguf_writer.add_rope_dimension_count(head_dim) + + @ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") class MambaModel(TextModel): model_arch = gguf.MODEL_ARCH.MAMBA @@ -7181,6 +7211,8 @@ class DeepseekModel(TextModel): "DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM", "KimiVLForConditionalGeneration", + "YoutuForCausalLM", + "YoutuVLForConditionalGeneration" ) class DeepseekV2Model(TextModel): model_arch = gguf.MODEL_ARCH.DEEPSEEK2 @@ -7247,7 +7279,15 @@ class DeepseekV2Model(TextModel): super().set_gguf_parameters() hparams = self.hparams - self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"]) + # first_k_dense_replace: number of leading layers using dense FFN instead of MoE + # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers + # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers + has_moe = hparams.get("n_routed_experts") is not None + first_k_dense_replace = hparams.get("first_k_dense_replace") + if first_k_dense_replace is None: + # Default: if no MoE, all layers are dense; if MoE, none are dense + first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0 + self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace) self.gguf_writer.add_vocab_size(hparams["vocab_size"]) if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None: self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"]) @@ -7259,11 +7299,24 @@ class DeepseekV2Model(TextModel): self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"]) self.gguf_writer.add_value_length_mla(hparams["v_head_dim"]) - self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"]) - self.gguf_writer.add_expert_count(hparams["n_routed_experts"]) - self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"]) - self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"]) - self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"]) + # MoE parameters (required by C++ code for DEEPSEEK2 arch) + # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length + moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False) + self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size) + + if (n_routed_experts := hparams.get("n_routed_experts")) is not None: + self.gguf_writer.add_expert_count(n_routed_experts) + + # expert_shared_count is required by C++ code, default to 0 for non-MoE models + n_shared_experts = hparams.get("n_shared_experts", 0) + self.gguf_writer.add_expert_shared_count(n_shared_experts) + + # When not set, C++ code will use scale_w = false to skip the no-op scaling + if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None: + self.gguf_writer.add_expert_weights_scale(routed_scaling_factor) + + if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob: + self.gguf_writer.add_expert_weights_norm(norm_topk_prob) self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"]) @@ -7279,10 +7332,17 @@ class DeepseekV2Model(TextModel): # skip vision tensors and remove "language_model." for Kimi-VL if "vision_tower" in name or "multi_modal_projector" in name: return [] - + if name.startswith("siglip2.") or name.startswith("merger."): + return [] if name.startswith("language_model."): name = name.replace("language_model.", "") + # skip lm_head.weight if tie_word_embeddings is True + if self.hparams.get("tie_word_embeddings", False): + if name == "lm_head.weight" or name == "model.lm_head.weight": + logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)") + return [] + # rename e_score_correction_bias tensors if name.endswith("e_score_correction_bias"): name = name.replace("e_score_correction_bias", "e_score_correction.bias") @@ -7429,7 +7489,6 @@ class MimoV2Model(TextModel): self.gguf_writer.add_sliding_window(self.hparams["sliding_window"]) self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"]) - self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"]) self.gguf_writer.add_value_length(self.hparams["v_head_dim"]) self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"]) self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"]) @@ -9897,6 +9956,27 @@ class LFM2Model(TextModel): return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"]) +@ModelBase.register("Lfm2Model") +class LFM2ColBertModel(LFM2Model): + model_arch = gguf.MODEL_ARCH.LFM2 + dense_tensor_name = "dense_2" + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + if not name.startswith(self.dense_tensor_name): + name = "model." + name + + return super().modify_tensors(data_torch, name, bid) + + def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: + # dense tensor is stored in a separate safetensors file + from safetensors.torch import load_file + tensors_file = self.dir_model / "1_Dense" / "model.safetensors" + assert tensors_file.is_file() + tensor = load_file(tensors_file)["linear.weight"] + self.gguf_writer.add_embedding_length_out(tensor.shape[0]) + yield f"{self.dense_tensor_name}.weight", tensor.clone() + + @ModelBase.register("Lfm2MoeForCausalLM") class LFM2MoeModel(TextModel): model_arch = gguf.MODEL_ARCH.LFM2MOE @@ -10167,7 +10247,6 @@ class ModernBertModel(BertModel): self.gguf_writer.add_sliding_window(self.hparams["local_attention"]) if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None: self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern) - self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"]) self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE) self.gguf_writer.add_vocab_size(self.hparams["vocab_size"]) @@ -10617,6 +10696,79 @@ class JanusProVisionModel(MmprojModel): return [] +@ModelBase.register("YoutuVLForConditionalGeneration") +class YoutuVLVisionModel(MmprojModel): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + assert self.hparams_vision is not None + self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560) + + def set_gguf_parameters(self): + super().set_gguf_parameters() + + self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL) + self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6)) + + # Handle activation function + hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower() + if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"): + self.gguf_writer.add_vision_use_gelu(True) + elif hidden_act == "silu": + self.gguf_writer.add_vision_use_silu(True) + else: + raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}") + + self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2)) + + window_size = self.hparams.get("window_size") + if window_size is not None: + self.gguf_writer.add_vision_window_size(window_size) + # fullatt_block_indexes contains explicit layer indices that use full attention + # e.g., [2, 5, 8, 11] means layers 2, 5, 8, 11 use full attention + # All other layers use window attention + fullatt_block_indexes = self.hparams.get("fullatt_block_indexes") + assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl" + # Store the explicit layer indices for YoutuVL (irregular pattern approach) + self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes) + + def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: + del bid # unused + + # Skip language model tensors + skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.') + if name.startswith(skip_prefixes): + return [] + + # Try to map the tensor using TensorNameMap (handles vision encoder and projector) + try: + new_name = self.map_tensor_name(name) + return [(new_name, data_torch)] + except ValueError: + # If mapping fails, log warning and skip + logger.warning(f"Cannot map tensor: {name}") + return [] + + +@ModelBase.register("SolarOpenForCausalLM") +class SolarOpenModel(Glm4MoeModel): + model_arch = gguf.MODEL_ARCH.GLM4_MOE + + def set_vocab(self): + from transformers import AutoTokenizer + tokenizer = AutoTokenizer.from_pretrained(self.dir_model) + special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True) + tokens, toktypes, tokpre = self.get_vocab_base() + self.gguf_writer.add_tokenizer_model("gpt2") + self.gguf_writer.add_tokenizer_pre(tokpre) + self.gguf_writer.add_token_list(tokens) + self.gguf_writer.add_token_types(toktypes) + special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"]) + special_vocab._set_special_token("unk", tokenizer.get_added_vocab()[""]) + special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"]) + special_vocab.add_to_gguf(self.gguf_writer) + + ###### CONVERSION LOGIC ###### @@ -10822,8 +10974,8 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--sentence-transformers-dense-modules", action="store_true", - help=("Whether to include sentence-transformers dense modules." - "It can be used for sentence-transformers models, like google/embeddinggemma-300m" + help=("Whether to include sentence-transformers dense modules. " + "It can be used for sentence-transformers models, like google/embeddinggemma-300m. " "Default these modules are not included.") ) diff --git a/convert_hf_to_gguf_update.py b/convert_hf_to_gguf_update.py index 4378378309..74c67e6a9c 100755 --- a/convert_hf_to_gguf_update.py +++ b/convert_hf_to_gguf_update.py @@ -145,6 +145,8 @@ models = [ {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", }, {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", }, {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", }, + {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", }, + {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", }, ] # some models are known to be broken upstream, so we will skip them as exceptions @@ -165,6 +167,8 @@ pre_computed_hashes = [ {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"}, {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"}, {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"}, + # jina-v2-de variants + {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"}, ] diff --git a/docs/backend/CANN.md b/docs/backend/CANN.md index 37dcfaef9a..b03c2a122c 100755 --- a/docs/backend/CANN.md +++ b/docs/backend/CANN.md @@ -327,3 +327,7 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe ### GGML_CANN_PREFILL_USE_GRAPH Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled. + +### GGML_CANN_OPERATOR_FUSION + +Enable operator fusion during computation, default is false. This option fuses compatible operators (e.g., ADD + RMS_NORM) to reduce overhead and improve performance. diff --git a/docs/backend/OPENCL.md b/docs/backend/OPENCL.md index ce6c7b5605..0561a74c47 100644 --- a/docs/backend/OPENCL.md +++ b/docs/backend/OPENCL.md @@ -218,6 +218,56 @@ cmake .. -G Ninja ` ninja ``` +## Linux + +The two steps just above also apply to Linux. When building for linux, the commands are mostly the same as those for PowerShell on Windows, but in the second step they do not have the `-DCMAKE_TOOLCHAIN_FILE` parameter, and then in both steps the backticks are replaced with back slashes. + +If not installed already, install Git, CMake, Clang, Ninja and Python, then run in the terminal the following: + +### I. Setup Environment + +1. **Install OpenCL Headers and Library** + +```bash +mkdir -p ~/dev/llm + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers +mkdir build && cd build +cmake .. -G Ninja \ + -DBUILD_TESTING=OFF \ + -DOPENCL_HEADERS_BUILD_TESTING=OFF \ + -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF \ + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install + +cd ~/dev/llm +git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader +mkdir build && cd build +cmake .. -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \ + -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl" +cmake --build . --target install +``` + +### II. Build llama.cpp + +```bash +mkdir -p ~/dev/llm +cd ~/dev/llm + +git clone https://github.com/ggml-org/llama.cpp && cd llama.cpp +mkdir build && cd build + +cmake .. -G Ninja \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \ + -DBUILD_SHARED_LIBS=OFF \ + -DGGML_OPENCL=ON +ninja +``` + ## Known Issues - Flash attention does not always improve performance. diff --git a/docs/ops.md b/docs/ops.md index 2b2770cb76..142f401d03 100644 --- a/docs/ops.md +++ b/docs/ops.md @@ -22,7 +22,7 @@ Legend: | ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | | ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ | -| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | ❌ | +| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ | | CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ | | CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | diff --git a/docs/ops/WebGPU.csv b/docs/ops/WebGPU.csv index bfff75e66f..8cd7e12001 100644 --- a/docs/ops/WebGPU.csv +++ b/docs/ops/WebGPU.csv @@ -35,8 +35,8 @@ "WebGPU: WebGPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne_a=[128,2,2,2],v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne_a=[5,7,11,13],v=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" @@ -77,8 +77,8 @@ "WebGPU: WebGPU","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" @@ -119,8 +119,8 @@ "WebGPU: WebGPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne_a=[128,2,2,2],v=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne_a=[5,7,11,13],v=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne_a=[5,7,11,13],v=0","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=0","support","0","no","WebGPU" @@ -161,8 +161,8 @@ "WebGPU: WebGPU","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","WebGPU" @@ -965,6 +965,7 @@ "WebGPU: WebGPU","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","0","no","WebGPU" "WebGPU: WebGPU","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","0","no","WebGPU" "WebGPU: WebGPU","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","0","no","WebGPU" +"WebGPU: WebGPU","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[2,2,1536,729],ne_kernel=[2,2,1536,4096],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","0","no","WebGPU" "WebGPU: WebGPU","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","WebGPU" "WebGPU: WebGPU","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","WebGPU" "WebGPU: WebGPU","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","WebGPU" @@ -4964,6 +4965,7 @@ "WebGPU: WebGPU","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","0","no","WebGPU" "WebGPU: WebGPU","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","0","no","WebGPU" "WebGPU: WebGPU","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","0","no","WebGPU" +"WebGPU: WebGPU","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","0","no","WebGPU" "WebGPU: WebGPU","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","ARGMAX","type=f32,ne=[32,1,1,1]","support","0","no","WebGPU" @@ -5715,15 +5717,15 @@ "WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","WebGPU" "WebGPU: WebGPU","L2_NORM","type=f32,ne=[64,5,4,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1024,4,1],ne_b=[3,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,1536,1,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,1536,4,1],ne_b=[3,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[6,2048,1,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[3,2048,4,1],ne_b=[3,2048,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","0","no","WebGPU" @@ -5733,6 +5735,15 @@ "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,1024,1,1],ne_b=[9,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[18,1024,1,1],ne_b=[9,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,1024,4,1],ne_b=[9,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,1536,1,1],ne_b=[9,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[18,1536,1,1],ne_b=[9,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,1536,4,1],ne_b=[9,1536,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,2048,1,1],ne_b=[9,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[18,2048,1,1],ne_b=[9,2048,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SSM_CONV","type=f32,ne_a=[9,2048,4,1],ne_b=[9,2048,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","WebGPU" "WebGPU: WebGPU","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","0","no","WebGPU" @@ -8662,7 +8673,7 @@ "WebGPU: WebGPU","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","SQR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" @@ -8674,8 +8685,8 @@ "WebGPU: WebGPU","LEAKY_RELU","type=f16,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f16,ne=[1024,1024,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f16,ne=[1024,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f16,ne=[1024,1024,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f16,ne=[1024,1024,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f16,ne=[7,1,5,3]","support","0","no","WebGPU" @@ -8688,7 +8699,7 @@ "WebGPU: WebGPU","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","WebGPU" "WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[10,5,4,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne=[10,2,2,2]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f32,ne=[10,2,2,2]","support","0","no","WebGPU" "WebGPU: WebGPU","SQR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" @@ -8700,8 +8711,8 @@ "WebGPU: WebGPU","LEAKY_RELU","type=f32,ne_a=[7,1,5,3],negative_slope=0.100000","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","FLOOR","type=f32,ne=[1024,1024,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" -"WebGPU: WebGPU","CEIL","type=f32,ne=[1024,1024,1,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne=[7,1,5,3]","support","1","yes","WebGPU" +"WebGPU: WebGPU","CEIL","type=f32,ne=[1024,1024,1,1]","support","1","yes","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" "WebGPU: WebGPU","ROUND","type=f32,ne=[1024,1024,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","TRUNC","type=f32,ne=[7,1,5,3]","support","0","no","WebGPU" @@ -8916,6 +8927,8 @@ "WebGPU: WebGPU","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","0","no","WebGPU" "WebGPU: WebGPU","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","0","no","WebGPU" "WebGPU: WebGPU","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","0","no","WebGPU" @@ -8968,6 +8981,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" @@ -8977,6 +8991,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" @@ -8987,11 +9002,13 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" @@ -9001,6 +9018,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" @@ -9011,11 +9029,13 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" @@ -9025,6 +9045,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" @@ -9035,11 +9056,13 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" @@ -9049,6 +9072,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" @@ -9059,6 +9083,7 @@ "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" +"WebGPU: WebGPU","ROPE","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" "WebGPU: WebGPU","ROPE","type=f16,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","1","yes","WebGPU" @@ -9184,6 +9209,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" @@ -9193,6 +9219,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" @@ -9203,11 +9230,13 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" @@ -9217,6 +9246,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" @@ -9227,11 +9257,13 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" @@ -9241,6 +9273,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" @@ -9251,11 +9284,13 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,40,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,52,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,64,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,1,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,71,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,8,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" @@ -9265,6 +9300,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=20,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,2,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,32,4,1],n_dims=32,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=128,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,12,2,1],n_dims=20,mode=8,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" @@ -9275,6 +9311,7 @@ "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,28,2,1],n_dims=32,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[80,16,2,1],n_dims=80,mode=24,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[128,16,2,1],n_dims=128,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" +"WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[16,16,8192,1],n_dims=16,mode=40,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f32,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=1,v=1,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f16,ne_a=[128,32,2,1],n_dims=128,mode=0,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" "WebGPU: WebGPU","ROPE_BACK","type=f16,ne_a=[64,128,2,1],n_dims=64,mode=2,n_ctx=512,fs=1.000000,ef=0.000000,af=1.000000,ff=0,v=0,inplace=0","support","0","no","WebGPU" @@ -9542,333 +9579,333 @@ "WebGPU: WebGPU","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","0","no","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","0","no","WebGPU" "WebGPU: WebGPU","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","0","no","WebGPU" -"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","0","no","WebGPU" +"WebGPU: WebGPU","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=nearest,flags=none","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=nearest","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=0","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear,transpose=1","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=0","support","0","no","WebGPU" "WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bicubic,transpose=1","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=0","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=513,transpose=1","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear,flags=none","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear,flags=align_corners","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear,flags=align_corners","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear,flags=align_corners","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic,flags=align_corners","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic,flags=align_corners","support","0","no","WebGPU" -"WebGPU: WebGPU","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic,flags=align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bicubic","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=0","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=bilinear|antialias,transpose=1","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|antialias","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[5,7,11,13],ne_tgt=[2,5,7,11],mode=bilinear|antialias","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bilinear|align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bilinear|align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bilinear|align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=bicubic|align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[1,4,3,2],ne_tgt=[2,8,3,2],mode=bicubic|align_corners","support","0","no","WebGPU" +"WebGPU: WebGPU","UPSCALE","type=f32,ne=[4,1,3,2],ne_tgt=[1,1,3,2],mode=bicubic|align_corners","support","0","no","WebGPU" "WebGPU: WebGPU","SUM","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","SUM_ROWS","type=f32,ne=[10,5,4,3],permute=0,slice=0","support","0","no","WebGPU" "WebGPU: WebGPU","SUM","type=f32,ne=[11,5,6,3],permute=[0,2,1,3]","support","0","no","WebGPU" @@ -9891,8 +9928,9 @@ "WebGPU: WebGPU","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","0","no","WebGPU" "WebGPU: WebGPU","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","0","no","WebGPU" "WebGPU: WebGPU","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","0","no","WebGPU" "WebGPU: WebGPU","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","0","no","WebGPU" "WebGPU: WebGPU","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","WebGPU" @@ -9903,6 +9941,7 @@ "WebGPU: WebGPU","CUMSUM","type=f32,ne=[10,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[127,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[128,5,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","CUMSUM","type=f32,ne=[128,128,4,4]","support","0","no","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[255,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[256,5,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","CUMSUM","type=f32,ne=[511,5,4,3]","support","0","no","WebGPU" @@ -9922,17 +9961,41 @@ "WebGPU: WebGPU","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","0","no","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","0","no","WebGPU" "WebGPU: WebGPU","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","0","no","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","WebGPU" +"WebGPU: WebGPU","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[64,64,2,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[79,79,5,3],ne_rhs=[417,79,5,3]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[80,80,2,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[79,80,2,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[81,80,2,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[80,80,8,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[79,80,8,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[81,80,8,8]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[84,84,4,4],ne_rhs=[32,84,4,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[95,95,8,8],ne_rhs=[40,95,8,8]","support","0","no","WebGPU" "WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","0","no","WebGPU" -"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[32,128,4,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,3,4],ne_rhs=[32,128,3,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,1],ne_rhs=[32,128,4,1]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[200,64,4,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[384,64,4,4]","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","WebGPU" +"WebGPU: WebGPU","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","WebGPU" "WebGPU: WebGPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","0","no","WebGPU" "WebGPU: WebGPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","0","no","WebGPU" "WebGPU: WebGPU","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","0","no","WebGPU" diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index 91797cf78a..a29dc707c3 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -15,6 +15,7 @@ llama_add_compile_flags() if (EMSCRIPTEN) else() add_subdirectory(batched) + add_subdirectory(debug) add_subdirectory(embedding) add_subdirectory(eval-callback) @@ -34,7 +35,6 @@ else() add_subdirectory(gen-docs) add_subdirectory(training) add_subdirectory(diffusion) - add_subdirectory(model-conversion) if (NOT GGML_BACKEND_DL) add_subdirectory(convert-llama2c-to-ggml) # these examples use the backends directly and cannot be built with dynamic loading diff --git a/examples/batched/batched.cpp b/examples/batched/batched.cpp index 36a12d299f..6b134b4f6f 100644 --- a/examples/batched/batched.cpp +++ b/examples/batched/batched.cpp @@ -68,7 +68,7 @@ int main(int argc, char ** argv) { auto sparams = llama_sampler_chain_default_params(); sparams.no_perf = false; - std::vector samplers; + std::vector sampler_configs; for (int32_t i = 0; i < n_parallel; ++i) { llama_sampler * smpl = llama_sampler_chain_init(sparams); @@ -78,7 +78,13 @@ int main(int argc, char ** argv) { llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); - samplers.push_back(smpl); + sampler_configs.push_back({ i, smpl }); + } + + // TODO: temporarily gated behind a flag + if (params.sampling.backend_sampling) { + ctx_params.samplers = sampler_configs.data(); + ctx_params.n_samplers = sampler_configs.size(); } llama_context * ctx = llama_init_from_model(model, ctx_params); @@ -180,7 +186,7 @@ int main(int argc, char ** argv) { continue; } - const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]); + const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]); // is it an end of generation? -> mark the stream as finished if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { @@ -236,15 +242,15 @@ int main(int argc, char ** argv) { __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); LOG("\n"); - llama_perf_sampler_print(samplers[0]); + llama_perf_sampler_print(sampler_configs[0].sampler); llama_perf_context_print(ctx); fprintf(stderr, "\n"); llama_batch_free(batch); - for (auto & sampler_config : samplers) { - llama_sampler_free(sampler_config); + for (auto & sampler_config : sampler_configs) { + llama_sampler_free(sampler_config.sampler); } llama_free(ctx); diff --git a/examples/model-conversion/CMakeLists.txt b/examples/debug/CMakeLists.txt similarity index 73% rename from examples/model-conversion/CMakeLists.txt rename to examples/debug/CMakeLists.txt index fc1746ce45..34593072be 100644 --- a/examples/model-conversion/CMakeLists.txt +++ b/examples/debug/CMakeLists.txt @@ -1,5 +1,5 @@ -set(TARGET llama-logits) -add_executable(${TARGET} logits.cpp) +set(TARGET llama-debug) +add_executable(${TARGET} debug.cpp) install(TARGETS ${TARGET} RUNTIME) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/examples/debug/README.md b/examples/debug/README.md new file mode 100644 index 0000000000..28e00c9342 --- /dev/null +++ b/examples/debug/README.md @@ -0,0 +1,54 @@ +# llama.cpp/examples/debug + +This is a utility intended to help debug a model by registering a callback that +logs GGML operations and tensor data. It can also store the generated logits or +embeddings as well as the prompt and token ids for comparision with the original +model. + +### Usage + +```shell +llama-debug \ + --hf-repo ggml-org/models \ + --hf-file phi-2/ggml-model-q4_0.gguf \ + --model phi-2-q4_0.gguf \ + --prompt hello \ + --save-logits \ + --verbose +``` +The tensor data is logged as debug and required the --verbose flag. The reason +for this is that while useful for a model with many layers there can be a lot of +output. You can filter the tensor names using the `--tensor-filter` option. + +A recommended approach is to first run without `--verbose` and see if the +generated logits/embeddings are close to the original model. If they are not, +then it might be required to inspect tensor by tensor and in that case it is +useful to enable the `--verbose` flag along with `--tensor-filter` to focus on +specific tensors. + +### Options +This example supports all standard `llama.cpp` options and also accepts the +following options: +```console +$ llama-debug --help +... + +----- example-specific params ----- + +--save-logits save final logits to files for verification (default: false) +--logits-output-dir PATH directory for saving logits output files (default: data) +--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times) +``` + +### Output Files + +When `--save-logits` is enabled, the following files are created in the output +directory: + +* `llamacpp-[-embeddings].bin` - Binary output (logits or embeddings) +* `llamacpp-[-embeddings].txt` - Text output (logits or embeddings, one per line) +* `llamacpp-[-embeddings]-prompt.txt` - Prompt text and token IDs +* `llamacpp-[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison + +These files can be compared against the original model's output to verify the +converted model. diff --git a/examples/debug/debug.cpp b/examples/debug/debug.cpp new file mode 100644 index 0000000000..9bc5d0abfd --- /dev/null +++ b/examples/debug/debug.cpp @@ -0,0 +1,421 @@ +#include "arg.h" +#include "common.h" +#include "log.h" +#include "llama.h" +#include "ggml.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static void print_usage(int, char ** argv) { + const std::string usage_template = R"( + example usage: + + Print tensors: + + {prog} -m model.gguf -p "Hello my name is" --verbose + + The tensors to be printed can be filtered with --tensor-filter option. + + Save logits/embeddings: + + {prog} -m model.gguf -p "Hello my name is" --save-logits + + Add --embedding to save embeddings)" "\n"; + + // Fix the source code indentation above that is introduced by the raw string literal. + std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n"); + usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]); + LOG("%s\n", usage.c_str()); +} + +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data); + +struct callback_data { + std::vector data; + std::vector tensor_filters; + + callback_data() = default; + + callback_data(common_params & params, const std::vector & filter_patterns) { + for (const auto & pattern : filter_patterns) { + try { + std::string anchored_pattern = "^" + pattern; + tensor_filters.emplace_back(anchored_pattern, std::regex::optimize); + } catch (const std::regex_error & e) { + throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what()); + } + } + params.cb_eval = ggml_debug; + params.cb_eval_user_data = this; + } +}; + +struct output_data { + float * data_ptr = nullptr; + int data_size = 0; + std::string type_suffix; + std::vector storage; + std::string prompt; + std::vector tokens; + + output_data(llama_context * ctx, const llama_model * model, const common_params & params) { + const llama_vocab * vocab = llama_model_get_vocab(model); + const bool add_bos = llama_vocab_get_add_bos(vocab); + + tokens = common_tokenize(ctx, params.prompt, add_bos); + prompt = params.prompt; + + if (params.embedding) { + const int n_embd = llama_model_n_embd_out(model); + const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE; + const int n_embd_count = pooling_enabled ? 1 : tokens.size(); + const int n_embeddings = n_embd * n_embd_count; + + float * embeddings; + if (pooling_enabled) { + embeddings = llama_get_embeddings_seq(ctx, 0); + storage.resize(n_embeddings); + common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize); + embeddings = storage.data(); + } else { + embeddings = llama_get_embeddings(ctx); + } + + data_ptr = embeddings; + data_size = n_embeddings; + type_suffix = "-embeddings"; + } else { + const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1); + const int n_logits = llama_vocab_n_tokens(vocab); + + data_ptr = const_cast(logits); + data_size = n_logits; + type_suffix = ""; + } + } +}; + +static std::string ggml_ne_string(const ggml_tensor * t) { + std::string str; + for (int i = 0; i < GGML_MAX_DIMS; ++i) { + str += std::to_string(t->ne[i]); + if (i + 1 < GGML_MAX_DIMS) { + str += ", "; + } + } + return str; +} + +static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) { + union { + float f; + uint32_t i; + } u; + u.i = (uint32_t)h.bits << 16; + return u.f; +} + +static float ggml_get_float_value(const uint8_t * data, ggml_type type, + const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { + size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; + switch (type) { + case GGML_TYPE_F16: + return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]); + case GGML_TYPE_F32: + return *(const float *) &data[i]; + case GGML_TYPE_I64: + return (float) *(const int64_t *) &data[i]; + case GGML_TYPE_I32: + return (float) *(const int32_t *) &data[i]; + case GGML_TYPE_I16: + return (float) *(const int16_t *) &data[i]; + case GGML_TYPE_I8: + return (float) *(const int8_t *) &data[i]; + case GGML_TYPE_BF16: + return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]); + default: + GGML_ABORT("fatal error"); + } +} + +static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) { + GGML_ASSERT(n > 0); + float sum = 0; + float sum_sq = 0.0; + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + sum += v; + sum_sq += v * v; + } + } + } + } + for (int64_t i3 = 0; i3 < ne[3]; i3++) { + LOG_DBG(" [\n"); + for (int64_t i2 = 0; i2 < ne[2]; i2++) { + if (i2 == n && ne[2] > 2*n) { + LOG_DBG(" ..., \n"); + i2 = ne[2] - n; + } + LOG_DBG(" [\n"); + for (int64_t i1 = 0; i1 < ne[1]; i1++) { + if (i1 == n && ne[1] > 2*n) { + LOG_DBG(" ..., \n"); + i1 = ne[1] - n; + } + LOG_DBG(" ["); + for (int64_t i0 = 0; i0 < ne[0]; i0++) { + if (i0 == n && ne[0] > 2*n) { + LOG_DBG("..., "); + i0 = ne[0] - n; + } + const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3); + LOG_DBG("%12.4f", v); + if (i0 < ne[0] - 1) { + LOG_DBG(", "); + } + } + LOG_DBG("],\n"); + } + LOG_DBG(" ],\n"); + } + LOG_DBG(" ]\n"); + LOG_DBG(" sum = %f\n", sum); + LOG_DBG(" sum_sq = %f\n", sum_sq); + } + + if (std::isnan(sum)) { + LOG_ERR("encountered NaN - aborting\n"); + exit(0); + } +} + +/** + * GGML operations callback during the graph execution. + * + * @param t current tensor + * @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor + * if we return true, a follow-up call will be made with ask=false in which we can do the actual collection. + * see ggml_backend_sched_eval_callback + * @param user_data user data to pass at each call back + * @return true to receive data or continue the graph, false otherwise + */ +static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) { + auto * cb_data = (callback_data *) user_data; + + const struct ggml_tensor * src0 = t->src[0]; + const struct ggml_tensor * src1 = t->src[1]; + + if (ask) { + return true; // Always retrieve data + } + + bool matches_filter = cb_data->tensor_filters.empty(); + + if (!matches_filter) { + for (const auto & filter : cb_data->tensor_filters) { + if (std::regex_search(t->name, filter)) { + matches_filter = true; + break; + } + } + } + + char src1_str[128] = {0}; + if (src1) { + snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str()); + } + + if (matches_filter) { + LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, + t->name, + ggml_type_name(t->type), + ggml_op_desc(t), + src0->name, + ggml_ne_string(src0).c_str(), + src1 ? src1_str : "", + ggml_ne_string(t).c_str()); + } + + const bool is_host = ggml_backend_buffer_is_host(t->buffer); + + if (!is_host) { + auto n_bytes = ggml_nbytes(t); + cb_data->data.resize(n_bytes); + ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes); + } + + if (!ggml_is_quantized(t->type) && matches_filter) { + uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data(); + ggml_print_tensor(data, t->type, t->ne, t->nb, 3); + } + + return true; +} + + +static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) { + std::filesystem::create_directory(output_dir); + auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix); + + // Save logits/embeddings to binary file. + { + std::filesystem::path filepath{base_path.string() + ".bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open binary output file: " + filepath.string()); + } + file.write(reinterpret_cast(output.data_ptr), output.data_size * sizeof(float)); + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save logits/embeddings to text file. + { + std::filesystem::path filepath{base_path.string() + ".txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open text output file: " + filepath.string()); + } + for (int i = 0; i < output.data_size; i++) { + file << i << ": " << output.data_ptr[i] << '\n'; + } + LOG("Data saved to %s\n", filepath.c_str()); + } + + // Save prompt and tokens to text file. + { + std::filesystem::path filepath{base_path.string() + "-prompt.txt"}; + std::ofstream file{filepath}; + if (!file) { + throw std::runtime_error("failed to open prompt output file: " + filepath.string()); + } + + file << "prompt: " << output.prompt << '\n'; + file << "n_tokens: " << output.tokens.size() << '\n'; + + file << "token ids: "; + for (size_t i = 0; i < output.tokens.size(); i++) { + file << output.tokens[i]; + if (i + 1 < output.tokens.size()) { + file << ", "; + } + } + file << '\n'; + LOG("Prompt saved to %s\n", filepath.c_str()); + } + + // Save token ids to binary file. + { + std::filesystem::path filepath{base_path.string() + "-tokens.bin"}; + std::ofstream file{filepath, std::ios::binary}; + if (!file) { + throw std::runtime_error("failed to open tokens binary file: " + filepath.string()); + } + file.write(reinterpret_cast(output.tokens.data()), output.tokens.size() * sizeof(llama_token)); + LOG("Tokens saved to %s\n", filepath.c_str()); + } + +} + +static void print_tokenized_prompt(llama_context * ctx, const std::vector & tokens, const std::string & prompt) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false"); + LOG("Input prompt: \"%s\"\n", prompt.c_str()); + LOG("Token ids (%zu):\n", tokens.size()); + + for (auto id : tokens) { + std::string piece(128, '\0'); + int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true); + if (n < 0) { + LOG_ERR("failed to convert token %d to piece\n", id); + continue; + } + piece.resize(n); + LOG("%s(%d) ", piece.c_str(), id); + } + LOG("\n"); +} + +static bool run(llama_context * ctx, const common_params & params) { + const llama_model * model = llama_get_model(ctx); + const llama_vocab * vocab = llama_model_get_vocab(model); + + const bool add_bos = llama_vocab_get_add_bos(vocab); + + std::vector tokens = common_tokenize(ctx, params.prompt, add_bos); + + if (tokens.empty()) { + LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__); + return false; + } + + if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) { + LOG_ERR("%s : failed to eval\n", __func__); + return false; + } + + print_tokenized_prompt(ctx, tokens, params.prompt); + + if (params.save_logits) { + output_data output {ctx, model, params}; + std::filesystem::path model_path{params.model.path}; + std::string model_name{model_path.stem().string()}; + save_output_data(output, model_name, params.logits_output_dir); + } + + return true; +} + +int main(int argc, char ** argv) { + common_params params; + + if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) { + return 1; + } + + common_init(); + + llama_backend_init(); + llama_numa_init(params.numa); + + callback_data cb_data(params, params.tensor_filter); + + auto llama_init = common_init_from_params(params); + + auto * model = llama_init->model(); + auto * ctx = llama_init->context(); + + if (model == nullptr || ctx == nullptr) { + LOG_ERR("%s : failed to init\n", __func__); + return 1; + } + + { + LOG_INF("\n"); + LOG_INF("%s\n", common_params_get_system_info(params).c_str()); + LOG_INF("\n"); + } + + if (!run(ctx, params)) { + return 1; + } + + LOG("\n"); + llama_perf_context_print(ctx); + + llama_backend_free(); + + return 0; +} diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 81111e81b2..d8eaaa2691 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector & toke } } -static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) { +static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) { const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); // clear previous kv_cache values (irrelevant for embeddings) @@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); } - float * out = output + embd_pos * n_embd; - common_embd_normalize(embd, out, n_embd, embd_norm); + float * out = output + embd_pos * n_embd_out; + common_embd_normalize(embd, out, n_embd_out, embd_norm); } } @@ -252,8 +252,8 @@ int main(int argc, char ** argv) { } // allocate output - const int n_embd = llama_model_n_embd(model); - std::vector embeddings(n_embd_count * n_embd, 0); + const int n_embd_out = llama_model_n_embd_out(model); + std::vector embeddings(n_embd_count * n_embd_out, 0); float * emb = embeddings.data(); // break into batches @@ -267,8 +267,8 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) { - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + float * out = emb + e * n_embd_out; + batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; s = 0; common_batch_clear(batch); @@ -280,8 +280,8 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + e * n_embd; - batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize); + float * out = emb + e * n_embd_out; + batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); if (params.embd_out.empty()) { LOG("\n"); @@ -289,19 +289,19 @@ int main(int argc, char ** argv) { if (pooling_type == LLAMA_POOLING_TYPE_NONE) { for (int j = 0; j < n_embd_count; j++) { LOG("embedding %d: ", j); - for (int i = 0; i < std::min(3, n_embd); i++) { + for (int i = 0; i < std::min(3, n_embd_out); i++) { if (params.embd_normalize == 0) { - LOG("%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd_out + i]); } else { - LOG("%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd_out + i]); } } LOG(" ... "); - for (int i = n_embd - 3; i < n_embd; i++) { + for (int i = n_embd_out - 3; i < n_embd_out; i++) { if (params.embd_normalize == 0) { - LOG("%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd_out + i]); } else { - LOG("%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd_out + i]); } } LOG("\n"); @@ -320,9 +320,9 @@ int main(int argc, char ** argv) { for (uint32_t i = 0; i < n_cls_out; i++) { // NOTE: if you change this log - update the tests in ci/run.sh if (n_cls_out == 1) { - LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]); + LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]); } else { - LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str()); + LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str()); } } } @@ -330,11 +330,11 @@ int main(int argc, char ** argv) { // print the first part of the embeddings or for a single prompt, the full embedding for (int j = 0; j < n_prompts; j++) { LOG("embedding %d: ", j); - for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) { + for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) { if (params.embd_normalize == 0) { - LOG("%6.0f ", emb[j * n_embd + i]); + LOG("%6.0f ", emb[j * n_embd_out + i]); } else { - LOG("%9.6f ", emb[j * n_embd + i]); + LOG("%9.6f ", emb[j * n_embd_out + i]); } } LOG("\n"); @@ -350,7 +350,7 @@ int main(int argc, char ** argv) { LOG("\n"); for (int i = 0; i < n_prompts; i++) { for (int j = 0; j < n_prompts; j++) { - float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); LOG("%6.2f ", sim); } LOG("%1.10s", prompts[i].c_str()); @@ -368,9 +368,9 @@ int main(int argc, char ** argv) { if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); LOG("["); for (int i = 0;;) { // at least one iteration (n_embd > 0) - LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]); + LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]); i++; - if (i < n_embd) LOG(","); else break; + if (i < n_embd_out) LOG(","); else break; } LOG(notArray ? "]\n }" : "]"); j++; @@ -383,7 +383,7 @@ int main(int argc, char ** argv) { for (int i = 0;;) { // at least two iteration (n_embd_count > 1) LOG(" ["); for (int j = 0;;) { // at least two iteration (n_embd_count > 1) - float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd); + float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); LOG("%6.2f", sim); j++; if (j < n_embd_count) LOG(", "); else break; @@ -397,7 +397,7 @@ int main(int argc, char ** argv) { if (notArray) LOG("\n}\n"); } else if (params.embd_out == "raw") { - print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize); + print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize); } LOG("\n"); diff --git a/examples/model-conversion/logits.cpp b/examples/model-conversion/logits.cpp deleted file mode 100644 index 5bcf063267..0000000000 --- a/examples/model-conversion/logits.cpp +++ /dev/null @@ -1,268 +0,0 @@ -#include "llama.h" -#include "common.h" - - -#include -#include -#include -#include -#include -#include - -static void print_usage(int, char ** argv) { - printf("\nexample usage:\n"); - printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm ] [prompt]\n", argv[0]); - printf("\n"); - printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n"); - printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n"); - printf("\n"); -} - -int main(int argc, char ** argv) { - std::string model_path; - std::string prompt = "Hello, my name is"; - int ngl = 0; - bool embedding_mode = false; - bool pooling_enabled = false; - int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm) - - { - int i = 1; - for (; i < argc; i++) { - if (strcmp(argv[i], "-m") == 0) { - if (i + 1 < argc) { - model_path = argv[++i]; - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-ngl") == 0) { - if (i + 1 < argc) { - try { - ngl = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else if (strcmp(argv[i], "-embd-mode") == 0) { - embedding_mode = true; - } else if (strcmp(argv[i], "-pooling") == 0) { - pooling_enabled = true; - } else if (strcmp(argv[i], "-embd-norm") == 0) { - if (i + 1 < argc) { - try { - embd_norm = std::stoi(argv[++i]); - } catch (...) { - print_usage(argc, argv); - return 1; - } - } else { - print_usage(argc, argv); - return 1; - } - } else { - // prompt starts here - break; - } - } - - if (model_path.empty()) { - print_usage(argc, argv); - return 1; - } - - if (i < argc) { - prompt = argv[i++]; - for (; i < argc; i++) { - prompt += " "; - prompt += argv[i]; - } - } - } - - ggml_backend_load_all(); - llama_model_params model_params = llama_model_default_params(); - model_params.n_gpu_layers = ngl; - - llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params); - - if (model == NULL) { - fprintf(stderr , "%s: error: unable to load model\n" , __func__); - return 1; - } - - // Extract basename from model_path - const char * basename = strrchr(model_path.c_str(), '/'); - basename = (basename == NULL) ? model_path.c_str() : basename + 1; - - char model_name[256]; - strncpy(model_name, basename, 255); - model_name[255] = '\0'; - - char * dot = strrchr(model_name, '.'); - if (dot != NULL && strcmp(dot, ".gguf") == 0) { - *dot = '\0'; - } - printf("Model name: %s\n", model_name); - - const llama_vocab * vocab = llama_model_get_vocab(model); - const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true); - - std::vector prompt_tokens(n_prompt); - if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) { - fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__); - return 1; - } - - llama_context_params ctx_params = llama_context_default_params(); - ctx_params.n_ctx = n_prompt; - ctx_params.n_batch = n_prompt; - ctx_params.no_perf = false; - if (embedding_mode) { - ctx_params.embeddings = true; - ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE; - ctx_params.n_ubatch = ctx_params.n_batch; - } - - llama_context * ctx = llama_init_from_model(model, ctx_params); - if (ctx == NULL) { - fprintf(stderr , "%s: error: failed to create the llama_context\n" , __func__); - return 1; - } - - printf("Input prompt: \"%s\"\n", prompt.c_str()); - printf("Tokenized prompt (%d tokens): ", n_prompt); - for (auto id : prompt_tokens) { - char buf[128]; - int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true); - if (n < 0) { - fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__); - return 1; - } - std::string s(buf, n); - printf("%s (%d)", s.c_str(), id); - } - printf("\n"); - - llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size()); - - if (llama_decode(ctx, batch)) { - fprintf(stderr, "%s : failed to eval\n", __func__); - return 1; - } - - float * data_ptr; - int data_size; - const char * type; - std::vector embd_out; - - if (embedding_mode) { - const int n_embd = llama_model_n_embd(model); - const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens; - const int n_embeddings = n_embd * n_embd_count; - float * embeddings; - type = "-embeddings"; - - if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) { - embeddings = llama_get_embeddings_seq(ctx, 0); - embd_out.resize(n_embeddings); - printf("Normalizing embeddings using norm: %d\n", embd_norm); - common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm); - embeddings = embd_out.data(); - } else { - embeddings = llama_get_embeddings(ctx); - } - - printf("Embedding dimension: %d\n", n_embd); - printf("\n"); - - // Print embeddings in the specified format - for (int j = 0; j < n_embd_count; j++) { - printf("embedding %d: ", j); - - // Print first 3 values - for (int i = 0; i < 3 && i < n_embd; i++) { - printf("%9.6f ", embeddings[j * n_embd + i]); - } - - printf(" ... "); - - // Print last 3 values - for (int i = n_embd - 3; i < n_embd; i++) { - if (i >= 0) { - printf("%9.6f ", embeddings[j * n_embd + i]); - } - } - - printf("\n"); - } - printf("\n"); - - printf("Embeddings size: %d\n", n_embeddings); - - data_ptr = embeddings; - data_size = n_embeddings; - } else { - float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1); - const int n_logits = llama_vocab_n_tokens(vocab); - type = ""; - printf("Vocab size: %d\n", n_logits); - - data_ptr = logits; - data_size = n_logits; - } - - std::filesystem::create_directory("data"); - - // Save data to binary file - char bin_filename[512]; - snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type); - printf("Saving data to %s\n", bin_filename); - - FILE * f = fopen(bin_filename, "wb"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open binary output file\n", __func__); - return 1; - } - fwrite(data_ptr, sizeof(float), data_size, f); - fclose(f); - - // Also save as text for debugging - char txt_filename[512]; - snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type); - f = fopen(txt_filename, "w"); - if (f == NULL) { - fprintf(stderr, "%s: error: failed to open text output file\n", __func__); - return 1; - } - for (int i = 0; i < data_size; i++) { - fprintf(f, "%d: %.6f\n", i, data_ptr[i]); - } - fclose(f); - - if (!embedding_mode) { - printf("First 10 logits: "); - for (int i = 0; i < 10 && i < data_size; i++) { - printf("%.6f ", data_ptr[i]); - } - printf("\n"); - - printf("Last 10 logits: "); - for (int i = data_size - 10; i < data_size; i++) { - if (i >= 0) printf("%.6f ", data_ptr[i]); - } - printf("\n\n"); - } - - printf("Data saved to %s\n", bin_filename); - printf("Data saved to %s\n", txt_filename); - - llama_free(ctx); - llama_model_free(model); - - return 0; -} diff --git a/examples/model-conversion/scripts/causal/compare-logits.py b/examples/model-conversion/scripts/causal/compare-logits.py index 894302c69e..1a933207d5 100755 --- a/examples/model-conversion/scripts/causal/compare-logits.py +++ b/examples/model-conversion/scripts/causal/compare-logits.py @@ -6,7 +6,7 @@ from pathlib import Path # Add utils directory to path for direct script execution sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) -from common import get_model_name_from_env_path # type: ignore[import-not-found] +from common import get_model_name_from_env_path, compare_tokens # type: ignore[import-not-found] def quick_logits_check(pytorch_file, llamacpp_file): """Lightweight sanity check before NMSE""" @@ -58,6 +58,13 @@ def main(): print("Checked all required files were found. Proceeding...\n") + # Verify tokens as they are a prerequisite for logits comparison. + print("🔍 Token Comparison Check") + print("=" * 40) + if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"): + print("\n❌ Token mismatch detected") + sys.exit(1) + print() print("🔍 GGML Model Validation for model ", model_name) print("=" * 40) diff --git a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py index 55ad821385..4ab778fbc7 100755 --- a/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py +++ b/examples/model-conversion/scripts/causal/run-casual-gen-embeddings-org.py @@ -67,7 +67,7 @@ with torch.no_grad(): last_hidden_states = outputs.hidden_states[-1] # Get embeddings for all tokens - token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension + token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension print(f"Hidden states shape: {last_hidden_states.shape}") print(f"Token embeddings shape: {token_embeddings.shape}") diff --git a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh index fa16a02c65..3cce3fc94d 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model-embeddings-logits.sh @@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then exit 1 fi -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today" +../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-converted-model.sh b/examples/model-conversion/scripts/causal/run-converted-model.sh index 529e9987b0..b6c3d38662 100755 --- a/examples/model-conversion/scripts/causal/run-converted-model.sh +++ b/examples/model-conversion/scripts/causal/run-converted-model.sh @@ -21,6 +21,6 @@ fi echo $CONVERTED_MODEL echo $MODEL_TESTING_PROMPT -cmake --build ../../build --target llama-logits -j8 +cmake --build ../../build --target llama-debug -j8 -../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT" +../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits diff --git a/examples/model-conversion/scripts/causal/run-org-model.py b/examples/model-conversion/scripts/causal/run-org-model.py index b12173a1fb..215f1a9ee0 100755 --- a/examples/model-conversion/scripts/causal/run-org-model.py +++ b/examples/model-conversion/scripts/causal/run-org-model.py @@ -7,12 +7,11 @@ import importlib import torch import numpy as np -from pathlib import Path from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig # Add parent directory to path for imports sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) -from utils.common import debug_hook +from utils.common import debug_hook, save_output_data def parse_arguments(): parser = argparse.ArgumentParser(description="Process model with specified path") @@ -126,6 +125,7 @@ def main(): device = next(model.parameters()).device prompt = get_prompt(args) input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) + token_ids = input_ids[0].cpu().tolist() print(f"Input tokens: {input_ids}") print(f"Input text: {repr(prompt)}") @@ -151,19 +151,6 @@ def main(): print(f"Last token logits shape: {last_logits.shape}") print(f"Vocab size: {len(last_logits)}") - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}.bin" - txt_filename = data_dir / f"pytorch-{model_name}.txt" - - # Save to file for comparison - last_logits.astype(np.float32).tofile(bin_filename) - - # Also save as text file for easy inspection - with open(txt_filename, "w") as f: - for i, logit in enumerate(last_logits): - f.write(f"{i}: {logit:.6f}\n") - # Print some sample logits for quick verification print(f"First 10 logits: {last_logits[:10]}") print(f"Last 10 logits: {last_logits[-10:]}") @@ -175,8 +162,7 @@ def main(): token = tokenizer.decode([idx]) print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") - print(f"Saved bin logits to: {bin_filename}") - print(f"Saved txt logist to: {txt_filename}") + save_output_data(last_logits, token_ids, prompt, model_name) if __name__ == "__main__": main() diff --git a/examples/model-conversion/scripts/embedding/run-converted-model.sh b/examples/model-conversion/scripts/embedding/run-converted-model.sh index 0f490e6c3b..5d264b0663 100755 --- a/examples/model-conversion/scripts/embedding/run-converted-model.sh +++ b/examples/model-conversion/scripts/embedding/run-converted-model.sh @@ -50,10 +50,9 @@ fi echo $CONVERTED_MODEL -cmake --build ../../build --target llama-logits -j8 -# TODO: update logits.cpp to accept a --file/-f option for the prompt +cmake --build ../../build --target llama-debug -j8 if [ -n "$USE_POOLING" ]; then - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits else - ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT" + ../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits fi diff --git a/examples/model-conversion/scripts/embedding/run-original-model.py b/examples/model-conversion/scripts/embedding/run-original-model.py index 774e5638f7..0802cbcf4a 100755 --- a/examples/model-conversion/scripts/embedding/run-original-model.py +++ b/examples/model-conversion/scripts/embedding/run-original-model.py @@ -3,13 +3,15 @@ import argparse import os import sys -import numpy as np import importlib -from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModel import torch +# Add parent directory to path for imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) +from utils.common import save_output_data + def parse_arguments(): parser = argparse.ArgumentParser(description='Run original embedding model') @@ -169,6 +171,7 @@ def main(): return_tensors="pt" ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -185,6 +188,7 @@ def main(): ) tokens = encoded['input_ids'][0] + token_ids = tokens.cpu().tolist() token_strings = tokenizer.convert_ids_to_tokens(tokens) for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): print(f"{token_id:6d} -> '{token_str}'") @@ -228,24 +232,11 @@ def main(): print() - data_dir = Path("data") - data_dir.mkdir(exist_ok=True) - bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin" - txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt" - flattened_embeddings = all_embeddings.flatten() - flattened_embeddings.astype(np.float32).tofile(bin_filename) - - with open(txt_filename, "w") as f: - idx = 0 - for j in range(n_embd_count): - for value in all_embeddings[j]: - f.write(f"{idx}: {value:.6f}\n") - idx += 1 print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)") print("") - print(f"Saved bin embeddings to: {bin_filename}") - print(f"Saved txt embeddings to: {txt_filename}") + + save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings") if __name__ == "__main__": diff --git a/examples/model-conversion/scripts/utils/common.py b/examples/model-conversion/scripts/utils/common.py index 7595d0410e..71761127bb 100644 --- a/examples/model-conversion/scripts/utils/common.py +++ b/examples/model-conversion/scripts/utils/common.py @@ -3,6 +3,8 @@ import os import sys import torch +import numpy as np +from pathlib import Path def get_model_name_from_env_path(env_path_name): @@ -148,3 +150,96 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_ # Patch it setattr(module, function_name, debug_rope) print(f"RoPE debug patching applied to {model_module_path}.{function_name}") + + +def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"): + """ + Save output data (logits/embeddings), tokens, and prompt to files. + + Args: + data: numpy array of floats (logits or embeddings) + tokens: list or array of token IDs + prompt: string containing the input prompt + model_name: name of the model + type_suffix: optional suffix like "-embeddings" (default: "") + output_dir: directory to save files (default: "data") + + Creates the following files in output_dir: + - pytorch-{model_name}{type_suffix}.bin + - pytorch-{model_name}{type_suffix}.txt + - pytorch-{model_name}{type_suffix}-prompt.txt + - pytorch-{model_name}{type_suffix}-tokens.bin + """ + data_dir = Path(output_dir) + data_dir.mkdir(exist_ok=True) + base_path = data_dir / f"pytorch-{model_name}{type_suffix}" + + # Convert and flatten logits/embeddings + data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data) + data = data.flatten() if data.ndim > 1 else data + + # Save logits/embedding files + data.astype(np.float32).tofile(f"{base_path}.bin") + print(f"Data saved to {base_path}.bin") + + with open(f"{base_path}.txt", "w") as f: + f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data)) + print(f"Data saved to {base_path}.txt") + + # Convert and flatten tokens + tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens) + tokens = tokens.flatten() if tokens.ndim > 1 else tokens + + # Save token binary file + tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin") + print(f"Tokens saved to {base_path}-tokens.bin") + + # Save prompt file + with open(f"{base_path}-prompt.txt", "w") as f: + f.write(f"prompt: {prompt}\n") + f.write(f"n_tokens: {len(tokens)}\n") + f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n") + print(f"Prompt saved to {base_path}-prompt.txt") + + +def compare_tokens(original, converted, type_suffix="", output_dir="data"): + data_dir = Path(output_dir) + + # Read tokens from both models + tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin" + tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin" + + if not tokens1_file.exists(): + print(f"Error: Token file not found: {tokens1_file}") + return False + + if not tokens2_file.exists(): + print(f"Error: Token file not found: {tokens2_file}") + return False + + tokens1 = np.fromfile(tokens1_file, dtype=np.int32) + tokens2 = np.fromfile(tokens2_file, dtype=np.int32) + + print(f"\nComparing tokens between:") + print(f" Original : {original} ({len(tokens1)} tokens)") + print(f" Converted: {converted} ({len(tokens2)} tokens)") + + if len(tokens1) != len(tokens2): + print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}") + return False + + if np.array_equal(tokens1, tokens2): + print(f"\n✅ All {len(tokens1)} tokens match!") + return True + + mismatches = np.where(tokens1 != tokens2)[0] + print(f"\n❌ Found {len(mismatches)} mismatched tokens:") + + num_to_show = min(len(mismatches), 10) + for idx in mismatches[:num_to_show]: + print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}") + + if len(mismatches) > num_to_show: + print(f" ... and {len(mismatches) - num_to_show} more mismatches") + + return False diff --git a/examples/model-conversion/scripts/utils/compare_tokens.py b/examples/model-conversion/scripts/utils/compare_tokens.py new file mode 100755 index 0000000000..a286cb5683 --- /dev/null +++ b/examples/model-conversion/scripts/utils/compare_tokens.py @@ -0,0 +1,76 @@ +#!/usr/bin/env python3 + +import argparse +import sys +from common import compare_tokens # type: ignore + + +def parse_arguments(): + parser = argparse.ArgumentParser( + description='Compare tokens between two models', + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + %(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16 + """ + ) + parser.add_argument( + 'original', + help='Original model name' + ) + parser.add_argument( + 'converted', + help='Converted model name' + ) + parser.add_argument( + '-s', '--suffix', + default='', + help='Type suffix (e.g., "-embeddings")' + ) + parser.add_argument( + '-d', '--data-dir', + default='data', + help='Directory containing token files (default: data)' + ) + parser.add_argument( + '-v', '--verbose', + action='store_true', + help='Print prompts from both models' + ) + return parser.parse_args() + + +def main(): + args = parse_arguments() + + if args.verbose: + from pathlib import Path + data_dir = Path(args.data_dir) + + prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt" + prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt" + + if prompt1_file.exists(): + print(f"\nOriginal model prompt ({args.original}):") + print(f" {prompt1_file.read_text().strip()}") + + if prompt2_file.exists(): + print(f"\nConverted model prompt ({args.converted}):") + print(f" {prompt2_file.read_text().strip()}") + + print() + + result = compare_tokens( + args.original, + args.converted, + type_suffix=args.suffix, + output_dir=args.data_dir + ) + + # Enable the script to be used in shell scripts so that they can check + # the exit code for success/failure. + sys.exit(0 if result else 1) + + +if __name__ == "__main__": + main() diff --git a/examples/model-conversion/scripts/utils/semantic_check.py b/examples/model-conversion/scripts/utils/semantic_check.py index e64c000497..38b03ce4d2 100644 --- a/examples/model-conversion/scripts/utils/semantic_check.py +++ b/examples/model-conversion/scripts/utils/semantic_check.py @@ -4,8 +4,10 @@ import numpy as np import argparse import os import importlib +from pathlib import Path from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel +from common import compare_tokens # type: ignore[import-not-found] unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') @@ -157,9 +159,25 @@ def main(): else: prompt = args.prompt + python_emb_path = Path(args.python_embeddings) + cpp_emb_path = Path(args.cpp_embeddings) + + # Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name") + python_model_name = python_emb_path.stem.replace("-embeddings", "") + cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "") + print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") print("=" * 70) + # First verify tokens match before comparing embeddings + print("\n🔍 Token Comparison Check") + print("=" * 70) + data_dir = python_emb_path.parent + if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)): + print("\n❌ Token mismatch detected") + exit(1) + print() + # Single prompt detailed comparison print(f"\nTesting with prompt: '{prompt}'") diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 8f92ff9057..3f2afd4346 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -217,8 +217,8 @@ int main(int argc, char ** argv) { struct llama_batch batch = llama_batch_init(n_batch, 0, 1); // allocate output - const int n_embd = llama_model_n_embd(model); - std::vector embeddings(n_chunks * n_embd, 0); + const int n_embd_out = llama_model_n_embd_out(model); + std::vector embeddings(n_chunks * n_embd_out, 0); float * emb = embeddings.data(); // break into batches @@ -232,8 +232,8 @@ int main(int argc, char ** argv) { // encode if at capacity if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) { - float * out = emb + p * n_embd; - batch_process(ctx, batch, out, s, n_embd); + float * out = emb + p * n_embd_out; + batch_process(ctx, batch, out, s, n_embd_out); common_batch_clear(batch); p += s; s = 0; @@ -245,12 +245,12 @@ int main(int argc, char ** argv) { } // final batch - float * out = emb + p * n_embd; - batch_process(ctx, batch, out, s, n_embd); + float * out = emb + p * n_embd_out; + batch_process(ctx, batch, out, s, n_embd_out); // save embeddings to chunks for (int i = 0; i < n_chunks; i++) { - chunks[i].embedding = std::vector(emb + i * n_embd, emb + (i + 1) * n_embd); + chunks[i].embedding = std::vector(emb + i * n_embd_out, emb + (i + 1) * n_embd_out); // clear tokens as they are no longer needed chunks[i].tokens.clear(); } @@ -266,8 +266,8 @@ int main(int argc, char ** argv) { batch_add_seq(query_batch, query_tokens, 0); - std::vector query_emb(n_embd, 0); - batch_process(ctx, query_batch, query_emb.data(), 1, n_embd); + std::vector query_emb(n_embd_out, 0); + batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out); common_batch_clear(query_batch); @@ -275,7 +275,7 @@ int main(int argc, char ** argv) { { std::vector> similarities; for (int i = 0; i < n_chunks; i++) { - float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd); + float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out); similarities.push_back(std::make_pair(i, sim)); } diff --git a/ggml/src/ggml-cann/aclnn_ops.cpp b/ggml/src/ggml-cann/aclnn_ops.cpp index 2180a06fd0..6b718e01c3 100644 --- a/ggml/src/ggml-cann/aclnn_ops.cpp +++ b/ggml/src/ggml-cann/aclnn_ops.cpp @@ -26,6 +26,7 @@ #include "ggml.h" #include +#include #include #include #include @@ -1962,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor * acl_tensor_ptr acl_weight_tensor; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (weight_to_nz && is_matmul_weight(weight)) { acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); } else { @@ -3805,3 +3806,57 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) { cubeMathType); } + +void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, + ggml_tensor * add_node, + ggml_tensor * rms_norm_node) { + // Get the two input tensors for ADD operation + ggml_tensor * x1 = add_node->src[0]; + ggml_tensor * x2 = add_node->src[1]; + + // Create ACL tensors for the two ADD inputs + acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1); + acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2); + + // Get epsilon parameter from rms_norm_tensor + float eps; + memcpy(&eps, rms_norm_node->op_params, sizeof(float)); + + // Build gamma tensor (RMS normalization scaling factor) + // Gamma should match the normalized dimensions (last dimension of x1) + size_t acl_gamma_nb[GGML_MAX_DIMS]; + acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type); + for (int i = 1; i < GGML_MAX_DIMS; i++) { + acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1]; + } + acl_tensor_ptr acl_gamma = + get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne, + acl_gamma_nb, rms_norm_node->type, + 1, // dims - only the last dimension + 1.0f // value + ); + + // Build rstdOut tensor (output for normalized standard deviation) + // Shape should be the dimensions that are NOT normalized + int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] }; + size_t acl_rstd_nb[GGML_MAX_DIMS - 1]; + acl_rstd_nb[0] = sizeof(float); + for (int i = 1; i < GGML_MAX_DIMS - 1; i++) { + acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1]; + } + acl_tensor_ptr acl_rstd = + get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size, + acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS, + 0.0f // value + ); + + acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node); + + // Create yOut tensor (final output after RMS normalization) + acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node); + + // Call fused ADD + RMS_NORM operator + GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(), + eps, // double type + acl_yout.get(), acl_rstd.get(), acl_xout.get()); +} diff --git a/ggml/src/ggml-cann/aclnn_ops.h b/ggml/src/ggml-cann/aclnn_ops.h index a6ea016c54..08ee7b1fbd 100644 --- a/ggml/src/ggml-cann/aclnn_ops.h +++ b/ggml/src/ggml-cann/aclnn_ops.h @@ -935,6 +935,20 @@ template void register_acl_resources(std::vector get_env(const std::string & name); +std::optional get_env_as_lowercase(const std::string & name); bool parse_bool(const std::string & value); int parse_integer(const std::string & value); diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp index ef23ec78da..162d238ae4 100644 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() { } /** - * @brief Get the value of the specified environment variable (name). + * @brief Get the value of the specified environment variable (name) as lowercase. * if not empty, return a std::string object */ -std::optional get_env(const std::string & name) { +std::optional get_env_as_lowercase(const std::string & name) { const char * val = std::getenv(name.c_str()); if (!val) { return std::nullopt; @@ -122,7 +122,7 @@ std::optional get_env(const std::string & name) { * @brief Verify whether the environment variable is a valid value. */ bool parse_bool(const std::string & value) { - std::unordered_set valid_values = { "on", "1", "yes", "y", "enable", "true" }; + static const std::unordered_set valid_values = { "on", "1", "yes", "y", "enable", "true" }; return valid_values.find(value) != valid_values.end(); } @@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf_prio(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool { * @param device The device ID to associate with this buffer pool. */ explicit ggml_cann_pool_buf(int device) : device(device) { - disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); + disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); } /** @@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool { * @return A unique pointer to the created CANN pool. */ std::unique_ptr ggml_backend_cann_context::new_pool_for_device(int device) { - std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or(""); + std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or(""); if (mem_pool_type == "prio") { GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); @@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer, // Why aclrtSynchronizeDevice? // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); if (!need_transform(tensor->type)) { ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { @@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t int64_t ne0 = tensor->ne[0]; // Only check env once. - static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on")); + static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); // last line must bigger than 32, because every single op deal at // least 32 bytes. @@ -1888,6 +1888,7 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg break; case GGML_OP_OUT_PROD: ggml_cann_out_prod(ctx, dst); + break; case GGML_OP_SSM_CONV: ggml_cann_ssm_conv(ctx, dst); break; @@ -2077,6 +2078,40 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) { ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); } +/** + * @brief Check if CANN backend can fuse the specified operation sequence + * + * This function determines whether an operation sequence starting from the specified node + * can be fused into an optimized operation in the CANN backend. Operation fusion can reduce + * memory access overhead and improve computational efficiency. + * + * @param cgraph Pointer to the computation graph + * @param node_idx Index of the starting node in the computation graph + * @param ops Sequence of operation types to check for fusion + * @return true if the operations can be fused + * @return false if the operations cannot be fused + */ +static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph, + int node_idx, + std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + // CANN backend supports fusing ADD + RMS_NORM operations + if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) { + ggml_tensor * add_node = cgraph->nodes[node_idx]; + // TODO: support broadcast for ADD + RMS_NORM + if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] || + add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) { + return false; + } + return true; + } + + return false; +} + /** * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * @@ -2101,9 +2136,18 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx #endif // USE_ACL_GRAPH // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // With the use of CANN graphs, the execution will be performed by the graph launch. + static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or("")); + if (!use_cann_graph || cann_graph_capture_required) { for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; + if (opt_fusion) { + if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) { + ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]); + i++; + continue; + } + } if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { @@ -2157,7 +2201,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend, #ifdef USE_ACL_GRAPH bool use_cann_graph = true; - static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); + static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); if (!prefill_use_graph) { // Do not use acl_graph for prefill. for (int i = 0; i < cgraph->n_nodes; i++) { diff --git a/ggml/src/ggml-cuda/CMakeLists.txt b/ggml/src/ggml-cuda/CMakeLists.txt index ae8f963f69..dcc004134d 100644 --- a/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ggml/src/ggml-cuda/CMakeLists.txt @@ -54,6 +54,20 @@ if (CUDAToolkit_FOUND) enable_language(CUDA) + # TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit + if (GGML_CUDA_CUB_3DOT2) + include(FetchContent) + + FetchContent_Declare( + CCCL + GIT_REPOSITORY https://github.com/nvidia/cccl.git + GIT_TAG v3.2.0-rc2 + GIT_SHALLOW TRUE + ) + + FetchContent_MakeAvailable(CCCL) + endif() + # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa. # 12X is forwards-compatible, 12Xa is not. # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa. @@ -143,6 +157,9 @@ if (CUDAToolkit_FOUND) # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) else () + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) else() @@ -150,6 +167,9 @@ if (CUDAToolkit_FOUND) endif() endif() else() + if (GGML_CUDA_CUB_3DOT2) + target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL) + endif() target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) endif() @@ -218,6 +238,10 @@ if (CUDAToolkit_FOUND) if (NOT MSVC) list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) + else() + # CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC + # https://github.com/NVIDIA/cccl/pull/6827 + list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor) endif() list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu index da9652c3be..57c8a99a28 100644 --- a/ggml/src/ggml-cuda/argsort.cu +++ b/ggml/src/ggml-cuda/argsort.cu @@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr } #ifdef GGML_CUDA_USE_CUB -static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, - const float * x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { ggml_cuda_pool_alloc temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc temp_keys_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc offsets_alloc(pool, nrows + 1); @@ -49,28 +49,49 @@ static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, size_t temp_storage_bytes = 0; if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) - temp_indices, dst, // values (indices) - ncols * nrows, nrows, // num items, num segments - d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols * nrows, nrows, // num items, num segments + d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, - dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, - sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices, + dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + } } ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); void * d_temp_storage = temp_storage_alloc.get(); if (order == GGML_SORT_ORDER_ASC) { - DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, - ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8, - stream); + if (nrows == 1) { + DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst, + ncols * nrows, nrows, d_offsets, d_offsets + 1, stream); + } } else { - DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, - temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, - 0, sizeof(float) * 8, stream); + if (nrows == 1) { + DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) + temp_indices, dst, // values (indices) + ncols, 0, sizeof(float) * 8, stream); + } else { + DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, + temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, + stream); + } } } #endif // GGML_CUDA_USE_CUB @@ -141,12 +162,12 @@ static int next_power_of_2(int x) { return n; } -static void argsort_f32_i32_cuda_bitonic(const float * x, - int * dst, - const int ncols, - const int nrows, - ggml_sort_order order, - cudaStream_t stream) { +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream) { // bitonic sort requires ncols to be power of 2 const int ncols_pad = next_power_of_2(ncols); diff --git a/ggml/src/ggml-cuda/argsort.cuh b/ggml/src/ggml-cuda/argsort.cuh index 68a001547f..22b7306f20 100644 --- a/ggml/src/ggml-cuda/argsort.cuh +++ b/ggml/src/ggml-cuda/argsort.cuh @@ -1,3 +1,19 @@ #include "common.cuh" void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#ifdef GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, + const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); +#endif // GGML_CUDA_USE_CUB +void argsort_f32_i32_cuda_bitonic(const float * x, + int * dst, + const int ncols, + const int nrows, + ggml_sort_order order, + cudaStream_t stream); diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh index 62e618850b..9516d8ec8f 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -950,15 +950,16 @@ struct ggml_cuda_device_info { int device_count; struct cuda_device_info { - int cc; // compute capability - int nsm; // number of streaming multiprocessors - size_t smpb; // max. shared memory per block - size_t smpbo; // max. shared memory per block (with opt-in) - bool integrated; // Device is integrated as opposed to discrete - bool vmm; // virtual memory support - size_t vmm_granularity; // granularity of virtual memory + int cc; // compute capability + int nsm; // number of streaming multiprocessors + size_t smpb; // max. shared memory per block + size_t smpbo; // max. shared memory per block (with opt-in) + bool integrated; // Device is integrated as opposed to discrete + bool vmm; // virtual memory support + size_t vmm_granularity; // granularity of virtual memory size_t total_vram; - int warp_size; // Number of threads in a dispatch + int warp_size; // Number of threads in a dispatch + bool supports_cooperative_launch; // whether cooperative launch is supported }; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; @@ -1035,7 +1036,7 @@ struct ggml_tensor_extra_gpu { #define USE_CUDA_GRAPH #endif -struct ggml_graph_node_properties { +struct ggml_cuda_graph_node_properties { void * node_address; ggml_op node_op; int64_t ne[GGML_MAX_DIMS]; @@ -1058,12 +1059,27 @@ struct ggml_cuda_graph { cudaGraphExec_t instance = nullptr; size_t num_nodes = 0; std::vector nodes; - std::vector params; bool disable_due_to_gpu_arch = false; bool disable_due_to_too_many_updates = false; - bool disable_due_to_failed_graph_capture = false; int number_consecutive_updates = 0; - std::vector ggml_graph_properties; + std::vector props; + + void record_update(bool use_graph, bool update_required) { + if (use_graph && update_required) { + number_consecutive_updates++; + } else { + number_consecutive_updates = 0; + } + if (number_consecutive_updates >= 4) { + GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); + disable_due_to_too_many_updates = true; + } + } + + bool is_enabled() const { + static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); + return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates); + } #endif }; diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu index c4ceb4fc57..ee84303ef0 100644 --- a/ggml/src/ggml-cuda/cpy.cu +++ b/ggml/src/ggml-cuda/cpy.cu @@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows template -static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; +static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne, } template -static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { +static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { const T* src = reinterpret_cast(cx); T* dst = reinterpret_cast(cdst); @@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) { } template -static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template -static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, - const int nb12, const int nb13) { - const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk; +static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, + const int64_t nb12, const int64_t nb13) { + const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; if (i >= ne) { return; } - const int i03 = i/(ne00 * ne01 * ne02); - const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); - const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; - const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; - const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; + const int64_t i03 = i/(ne00 * ne01 * ne02); + const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); + const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; + const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; + const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; - const int i13 = i/(ne10 * ne11 * ne12); - const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); - const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; - const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; - const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; + const int64_t i13 = i/(ne10 * ne11 * ne12); + const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); + const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; + const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; + const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; cpy_blck(cx + x_offset, cdst + dst_offset); } template static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { - const int64_t i = blockDim.x*blockIdx.x + threadIdx.x; + const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; if (i >= ne) { return; @@ -188,19 +188,20 @@ static void ggml_cpy_scalar_contiguous_cuda( cudaStream_t stream) { const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar_contiguous<<>> (cx, cdst, ne); } template static void ggml_cpy_scalar_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { if (transposed) { GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed - int ne00n, ne01n, ne02n; + int64_t ne00n, ne01n, ne02n; if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here ne00n = ne00; ne01n = ne01; @@ -211,143 +212,159 @@ static void ggml_cpy_scalar_cuda( ne02n = 1; } - dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D, - (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM); + int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; + int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; + GGML_ASSERT(grid_x < UINT_MAX); + GGML_ASSERT(grid_y < USHRT_MAX); + GGML_ASSERT(grid_z < USHRT_MAX); + dim3 dimGrid(grid_x, grid_y, grid_z); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); cpy_scalar_transpose<<>> (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } else { - const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_scalar><<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } } static void ggml_cpy_f32_q8_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK8_0 == 0); - const int num_blocks = ne / QK8_0; + const int64_t num_blocks = ne / QK8_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q8_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_0 == 0); - const int num_blocks = ne / QK4_0; + const int64_t num_blocks = ne / QK4_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q4_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_1 == 0); - const int num_blocks = ne / QK4_1; + const int64_t num_blocks = ne / QK4_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q4_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK4_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_0_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_0 == 0); - const int num_blocks = ne / QK5_0; + const int64_t num_blocks = ne / QK5_0; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_0_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_0><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_q5_1_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK5_1 == 0); - const int num_blocks = ne / QK5_1; + const int64_t num_blocks = ne / QK5_1; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_q5_1_f32_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, - const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, - const int nb10, const int nb11, const int nb12, const int nb13, + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, + const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, + const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { - const int num_blocks = ne; + const int64_t num_blocks = ne; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_q_f32, QK5_1><<>>( cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } static void ggml_cpy_f32_iq4_nl_cuda( - const char * cx, char * cdst, const int ne, - const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, - const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) { + const char * cx, char * cdst, const int64_t ne, + const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, + const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { GGML_ASSERT(ne % QK4_NL == 0); - const int num_blocks = ne / QK4_NL; + const int64_t num_blocks = ne / QK4_NL; + GGML_ASSERT(num_blocks < UINT_MAX); cpy_f32_q<<>> (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); } @@ -356,9 +373,6 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg const int64_t ne = ggml_nelements(src0); GGML_ASSERT(ne == ggml_nelements(src1)); - GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX); - GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX); - const int64_t ne00 = src0->ne[0]; const int64_t ne01 = src0->ne[1]; const int64_t ne02 = src0->ne[2]; diff --git a/ggml/src/ggml-cuda/cumsum.cu b/ggml/src/ggml-cuda/cumsum.cu index 3bd1394c51..def9c32955 100644 --- a/ggml/src/ggml-cuda/cumsum.cu +++ b/ggml/src/ggml-cuda/cumsum.cu @@ -5,7 +5,7 @@ #include "ggml.h" #ifdef GGML_CUDA_USE_CUB -# include +# include #endif // GGML_CUDA_USE_CUB template @@ -185,9 +185,34 @@ static __global__ void cumsum_kernel( } } +#ifdef GGML_CUDA_USE_CUB +template +static void cumsum_cub(ggml_cuda_pool & pool, + const T * src, + T * dst, + int64_t ne, + cudaStream_t stream) { + size_t tmp_size = 0; + + // Query how much temp storage CUDA UnBound (CUB) needs + cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size) + tmp_size, // reference to size (will be set by CUB) + src, // input pointer + dst, // output pointer + ne, // number of elements + stream // CUDA stream to use + ); + + ggml_cuda_pool_alloc tmp_alloc(pool, tmp_size); + + // Perform the inclusive scan + cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream); +} +#endif // GGML_CUDA_USE_CUB + template static void cumsum_cuda( - const T * src, T * dst, + [[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, @@ -201,6 +226,15 @@ static void cumsum_cuda( if (is_contiguous) { use_cub = true; + const int64_t nrows = ne01 * ne02 * ne03; + // TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released + // Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004 + if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) { + for (int i=0; idata, (float *)dst->data, + ctx, (const float *)src0->data, (float *)dst->data, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], diff --git a/ggml/src/ggml-cuda/fattn-common.cuh b/ggml/src/ggml-cuda/fattn-common.cuh index 8dc82a9d3b..3144678728 100644 --- a/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ggml/src/ggml-cuda/fattn-common.cuh @@ -11,10 +11,12 @@ #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable -// by the VKQ accumulators is effectively being shifted up by a factor of 8. +// by the VKQ accumulators is effectively being shifted up by a factor of 2. // This reduces issues with numerical overflow but also causes larger values to be flushed to zero. // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible. -#define FATTN_KQ_MAX_OFFSET 0.6931f +// Still, the value range should be shifted as much as necessary but as little as possible. +// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 . +#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f) typedef void (* fattn_kernel_t)( const char * __restrict__ Q, @@ -918,7 +920,9 @@ void launch_fattn( blocks_num.y = 1; blocks_num.z = 1; - dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); + if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. + dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2))); + } } else { const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu index 84eccea3f7..bac69cdd1c 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -19,6 +19,7 @@ #include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cross-entropy-loss.cuh" +#include "ggml-cuda/cumsum.cuh" #include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diag.cuh" #include "ggml-cuda/fattn.cuh" @@ -44,6 +45,7 @@ #include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/sum.cuh" #include "ggml-cuda/sumrows.cuh" +#include "ggml-cuda/top-k.cuh" #include "ggml-cuda/mean.cuh" #include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/topk-moe.cuh" @@ -231,6 +233,14 @@ static ggml_cuda_device_info ggml_cuda_init() { info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; + +#ifndef GGML_USE_MUSA + int supports_coop_launch = 0; + CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id)); + info.devices[id].supports_cooperative_launch = !!supports_coop_launch; +#else + info.devices[id].supports_cooperative_launch = false; +#endif // !(GGML_USE_MUSA) #if defined(GGML_USE_HIP) info.devices[id].smpbo = prop.sharedMemPerBlock; @@ -2677,6 +2687,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SUM: ggml_cuda_op_sum(ctx, dst); break; + case GGML_OP_CUMSUM: + ggml_cuda_op_cumsum(ctx, dst); + break; case GGML_OP_SUM_ROWS: ggml_cuda_op_sum_rows(ctx, dst); break; @@ -2689,6 +2702,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_SSM_SCAN: ggml_cuda_op_ssm_scan(ctx, dst); break; + case GGML_OP_TOP_K: + ggml_cuda_op_top_k(ctx, dst); + break; case GGML_OP_ARGSORT: ggml_cuda_op_argsort(ctx, dst); break; @@ -2698,9 +2714,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_OP_CROSS_ENTROPY_LOSS: ggml_cuda_cross_entropy_loss(ctx, dst); break; - case GGML_OP_CUMSUM: - ggml_cuda_op_cumsum(ctx, dst); - break; case GGML_OP_TRI: ggml_cuda_op_tri(ctx, dst); break; @@ -2840,9 +2853,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) { } #ifdef USE_CUDA_GRAPH -static bool check_node_graph_compatibility(ggml_cgraph * cgraph, - bool use_cuda_graph) { +static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { + bool use_cuda_graph = true; // Loop over nodes in GGML graph to obtain info needed for CUDA graph const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; @@ -2902,41 +2915,41 @@ static bool check_node_graph_compatibility(ggml_cgraph * cgraph, return use_cuda_graph; } -static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - graph_node_properties->node_address = node->data; - graph_node_properties->node_op = node->op; +static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { + props->node_address = node->data; + props->node_op = node->op; for (int i = 0; i < GGML_MAX_DIMS; i++) { - graph_node_properties->ne[i] = node->ne[i]; - graph_node_properties->nb[i] = node->nb[i]; + props->ne[i] = node->ne[i]; + props->nb[i] = node->nb[i]; } for (int i = 0; i < GGML_MAX_SRC; i++) { - graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; + props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; } - memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS); + memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); } -static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) { - if (node->data != graph_node_properties->node_address && +static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { + if (node->data != props->node_address && node->op != GGML_OP_VIEW) { return false; } - if (node->op != graph_node_properties->node_op) { + if (node->op != props->node_op) { return false; } for (int i = 0; i < GGML_MAX_DIMS; i++) { - if (node->ne[i] != graph_node_properties->ne[i]) { + if (node->ne[i] != props->ne[i]) { return false; } - if (node->nb[i] != graph_node_properties->nb[i]) { + if (node->nb[i] != props->nb[i]) { return false; } } for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->src[i] && - node->src[i]->data != graph_node_properties->src_address[i] && + node->src[i]->data != props->src_address[i] && node->op != GGML_OP_VIEW ) { return false; @@ -2944,44 +2957,55 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra } if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && - memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { + memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { return false; } return true; } -static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { +static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { - bool cuda_graph_update_required = false; + bool res = false; if (cuda_ctx->cuda_graph->instance == nullptr) { - cuda_graph_update_required = true; + res = true; } // Check if the graph size has changed - if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { - cuda_graph_update_required = true; - cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes); + if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { + res = true; + cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); } // Loop over nodes in GGML graph to determine if CUDA graph update is required // and store properties to allow this comparison for the next token for (int i = 0; i < cgraph->n_nodes; i++) { - bool has_matching_properties = true; - if (!cuda_graph_update_required) { - has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + bool props_match = true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); } - if (!has_matching_properties) { - cuda_graph_update_required = true; + if (!props_match) { + res = true; } - set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]); + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); } - return cuda_graph_update_required; + for (int i = 0; i < cgraph->n_leafs; i++) { + bool props_match= true; + if (!res) { + props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]); + } + if (!props_match) { + res = true; + } + ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); + } + + return res; } -static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) { +static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { #if CUDART_VERSION >= 12000 cudaGraphExecUpdateResultInfo result_info; @@ -3212,10 +3236,11 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, return false; } -static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, - bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { +static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { + bool graph_evaluated_or_captured = false; + // flag used to determine whether it is an integrated_gpu - const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; + const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); bool is_concurrent_event_active = false; @@ -3253,6 +3278,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); } } + if (should_launch_concurrent_events) { // Restore original node order within each concurrent region to enable fusion within streams @@ -3304,6 +3330,8 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx cgraph->nodes[start_pos + i] = const_cast(event.original_order[i]); } } + } else { + stream_ctx.concurrent_events.clear(); } for (int i = 0; i < cgraph->n_nodes; i++) { @@ -3682,7 +3710,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); } if (cuda_graph_update_required) { // Update graph executable - update_cuda_graph_executable(cuda_ctx); + ggml_cuda_graph_update_executable(cuda_ctx); } // Launch graph CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); @@ -3692,60 +3720,45 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx } } -static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { - ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; - - ggml_cuda_set_device(cuda_ctx->device); +static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { #ifdef USE_CUDA_GRAPH - static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr); - // Objects required for CUDA Graph if (cuda_ctx->cuda_graph == nullptr) { cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); } - bool use_cuda_graph = true; - bool cuda_graph_update_required = false; - if (cuda_ctx->cuda_graph->graph == nullptr) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; -#ifndef NDEBUG GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); -#endif } } - // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly, - // or previous graph capture failure. - // Also disable for multi-gpu for now. TO DO investigate - if (disable_cuda_graphs_due_to_env - || cuda_ctx->cuda_graph->disable_due_to_gpu_arch - || cuda_ctx->cuda_graph->disable_due_to_too_many_updates - || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) { - use_cuda_graph = false; - } - - if (use_cuda_graph) { - cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph); - - use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph); - - // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates. - if (use_cuda_graph && cuda_graph_update_required) { - cuda_ctx->cuda_graph->number_consecutive_updates++; - } else { - cuda_ctx->cuda_graph->number_consecutive_updates = 0; - } - - if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) { - cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true; -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__); -#endif - } + return cuda_ctx->cuda_graph->is_enabled(); +#else + return false; +#endif // USE_CUDA_GRAPH +} + +static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + + ggml_cuda_set_device(cuda_ctx->device); + + bool use_cuda_graph = false; + bool cuda_graph_update_required = false; + +#ifdef USE_CUDA_GRAPH + use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); + + if (cuda_ctx->cuda_graph->is_enabled()) { + cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); + use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); + + cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required); } +#endif // USE_CUDA_GRAPH if (use_cuda_graph && cuda_graph_update_required) { // Start CUDA graph capture @@ -3757,14 +3770,7 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); } -#else - bool use_cuda_graph = false; - bool cuda_graph_update_required = false; -#endif // USE_CUDA_GRAPH - - bool graph_evaluated_or_captured = false; - - evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required); + ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); return GGML_STATUS_SUCCESS; } @@ -3797,8 +3803,10 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; + const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); + static bool enable_graph_optimization = [] { - const char * env = getenv("GGML_CUDA_GRAPH_OPT"); + const char * env = getenv("GGML_CUDA_GRAPH_OPT"); return env != nullptr && atoi(env) == 1; }(); @@ -3806,12 +3814,13 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph return; } - GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend"); - GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes); - ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); stream_context.reset(); + if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) { + return; + } + // number of out-degrees for a particular node std::unordered_map fan_out; // reverse mapping of node to index in the cgraph @@ -3872,6 +3881,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph if (count >= min_fan_out && count <= max_fan_out) { const int root_node_idx = node_indices[root_node]; + // only optimize for attn_norm + // TODO: make this more generic + if (!strstr(root_node->name, "attn_norm")) { + continue; + } + bool is_part_of_event = false; for (const auto & [start, end] : concurrent_node_ranges) { if (root_node_idx >= start && root_node_idx <= end) { @@ -4600,6 +4615,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g return true; case GGML_OP_SUM: return ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_TOP_K: case GGML_OP_ARGSORT: #ifndef GGML_CUDA_USE_CUB return op->src[0]->ne[0] <= 1024; diff --git a/ggml/src/ggml-cuda/mean.cu b/ggml/src/ggml-cuda/mean.cu index 691d8dcb14..60542fc19d 100644 --- a/ggml/src/ggml-cuda/mean.cu +++ b/ggml/src/ggml-cuda/mean.cu @@ -34,13 +34,11 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { // CUDA_GRAPHS_DISABLED ((ncols > 65536) && ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture)) || + ctx.cuda_graph->is_enabled())) || // CUDA_GRAPHS ENABLED ((ncols > 32768) && !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || - ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates || - ctx.cuda_graph->disable_due_to_failed_graph_capture))) { + ctx.cuda_graph->is_enabled()))) { #else (ncols > 65536)) { #endif // USE_CUDA_GRAPH diff --git a/ggml/src/ggml-cuda/mmq.cu b/ggml/src/ggml-cuda/mmq.cu index 85692d4543..ceb95758d2 100644 --- a/ggml/src/ggml-cuda/mmq.cu +++ b/ggml/src/ggml-cuda/mmq.cu @@ -333,6 +333,28 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t } if (amd_wmma_available(cc)) { + // RDNA 4 is consistently worse on rocblas + // https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301 + if (GGML_CUDA_CC_IS_RDNA3(cc)) { + // High expert counts almost always better on MMQ + // due to a large amount of graph splits + // https://github.com/ggml-org/llama.cpp/pull/18202 + if (n_experts >= 64) { + return true; + } + + switch (type) { + // These quants are really bad on MMQ + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q6_K: + // These quants are usually worse but not always + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + return ne11 <= 128; + default: + return true; + } + } return true; } diff --git a/ggml/src/ggml-cuda/softmax.cu b/ggml/src/ggml-cuda/softmax.cu index eeacde0bdb..1ae84ebf63 100644 --- a/ggml/src/ggml-cuda/softmax.cu +++ b/ggml/src/ggml-cuda/softmax.cu @@ -1,6 +1,14 @@ #include "common.cuh" #include "ggml.h" #include "softmax.cuh" + +#ifdef GGML_USE_HIP +#include +#else +#include +#include +#endif // GGML_USE_HIP + #include #include @@ -160,6 +168,156 @@ static __global__ void soft_max_f32( dst[col] = vals[col] * inv_sum; } } + + +// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated +static __device__ float two_stage_warp_reduce_max(float val) { + val = warp_reduce_max(val); + if (blockDim.x > WARP_SIZE) { + assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); + __shared__ float local_vals[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + local_vals[warp_id] = val; + } + __syncthreads(); + val = -INFINITY; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + val = local_vals[lane_id]; + } + return warp_reduce_max(val); + } else { + return val; + } +} + +static __device__ float two_stage_warp_reduce_sum(float val) { + val = warp_reduce_sum(val); + if (blockDim.x > WARP_SIZE) { + assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0); + __shared__ float local_vals[32]; + const int warp_id = threadIdx.x / WARP_SIZE; + const int lane_id = threadIdx.x % WARP_SIZE; + if (lane_id == 0) { + local_vals[warp_id] = val; + } + __syncthreads(); + val = 0.0f; + if (lane_id < (static_cast(blockDim.x) / WARP_SIZE)) { + val = local_vals[lane_id]; + } + return warp_reduce_sum(val); + } else { + return val; + } +} + +// TODO: Template to allow keeping ncols in registers if they fit +static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) { + namespace cg = cooperative_groups; + + const cg::grid_group g = cg::this_grid(); + + const int tid = threadIdx.x; + const int col_start = blockIdx.x * blockDim.x + tid; + const int n_elem_per_thread = 4; + + float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY }; + float local_max = -INFINITY; + const int step_size = gridDim.x * blockDim.x; + + // Compute thread-local max + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + local_max = fmaxf(local_max, local_vals[i]); + } + col += step_size * n_elem_per_thread; + } + + // Compute CTA-level max + local_max = two_stage_warp_reduce_max(local_max); + + // Store CTA-level max to GMEM + if (tid == 0) { + tmp_maxs[blockIdx.x] = local_max; + } + g.sync(); + + // Compute compute global max from CTA-level maxs + assert(gridDim.x < blockDim.x); // currently we only support this case + if (tid < gridDim.x) { + local_max = tmp_maxs[tid]; + } else { + local_max = -INFINITY; + } + local_max = two_stage_warp_reduce_max(local_max); + + // Compute softmax dividends, accumulate divisor + float tmp_expf = 0.0f; + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + const float tmp = expf(local_vals[i] - local_max); + tmp_expf += tmp; + dst[idx] = tmp; + } + } + col += step_size * n_elem_per_thread; + } + + // Reduce divisor within CTA + tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + + // Store CTA-level sum to GMEM + if (tid == 0) { + tmp_sums[blockIdx.x] = tmp_expf; + } + g.sync(); + + // Compute global sum from CTA-level sums + if (tid < gridDim.x) { + tmp_expf = tmp_sums[tid]; + } else { + tmp_expf = 0.0f; + } + tmp_expf = two_stage_warp_reduce_sum(tmp_expf); + + // Divide dividend by global sum + store data + for (int col = col_start; col < p.ncols;) { +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY; + } +#pragma unroll + for (int i = 0; i < n_elem_per_thread; i++) { + const int idx = col + i * step_size; + if (idx < p.ncols) { + dst[idx] = local_vals[i] / tmp_expf; + } + } + col += step_size * n_elem_per_thread; + } +} + #ifdef __clang__ #pragma clang diagnostic pop #endif // __clang__ @@ -216,9 +374,31 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float soft_max_f32<<>>(x, mask, sinks, dst, p); } +__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x, + float * __restrict__ dst, + float * __restrict__ tmp_maxs, + float * __restrict__ tmp_sums, + const soft_max_params p) +// We loop over all instead of parallelizing across gridDim.y as cooperative groups +// currently only support synchronizing the complete grid if not launched as a cluster group +// (which requires CC > 9.0) +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization +// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group +{ + for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) { + soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs, + tmp_sums, p); + } +} -template -static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) { +template +static void soft_max_f32_cuda(const float * x, + const T * mask, + const float * sinks, + float * dst, + const soft_max_params & params, + cudaStream_t stream, + [[maybe_unused]] ggml_backend_cuda_context & ctx) { int nth = WARP_SIZE; const int64_t ncols_x = params.ncols; @@ -236,8 +416,25 @@ static void soft_max_f32_cuda(const float * x, const T * mask, const float * sin if (nbytes_shared <= smpbo) { launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); } else { - const size_t nbytes_shared_low = WARP_SIZE*sizeof(float); - soft_max_f32<<>>(x, mask, sinks, dst, params); + // Parallelize across SMs for top-p/dist-sampling + // The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and + // Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution. + if (ggml_cuda_info().devices[id].supports_cooperative_launch && + ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr && + params.scale == 1.0f && params.max_bias == 0.0f) { + ggml_cuda_pool_alloc tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + ggml_cuda_pool_alloc tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float)); + + void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr, + (void *) &tmp_sums_alloc.ptr, (void *) const_cast(¶ms) }; + CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols, + dim3(ggml_cuda_info().devices[id].nsm, 1, 1), + dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream)); + } else { + const size_t nbytes_shared_low = WARP_SIZE * sizeof(float); + soft_max_f32 + <<>>(x, mask, sinks, dst, params); + } } } @@ -315,9 +512,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { params.m1 = m1; if (use_f16) { - soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } else { - soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream); + soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); } } diff --git a/ggml/src/ggml-cuda/ssm-scan.cu b/ggml/src/ggml-cuda/ssm-scan.cu index 6b424381df..c1d4e2bc8d 100644 --- a/ggml/src/ggml-cuda/ssm-scan.cu +++ b/ggml/src/ggml-cuda/ssm-scan.cu @@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1) #endif // __clang__ // assumes as many threads as d_state -template +template __global__ void __launch_bounds__(d_state, 1) ssm_scan_f32_group( const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, @@ -125,20 +125,25 @@ __global__ void __launch_bounds__(d_state, 1) const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { - const int head_idx = (blockIdx.x * splitH) / d_head; - const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float); - const int seq_idx = blockIdx.y; + const int warp = threadIdx.x / WARP_SIZE; + const int lane = threadIdx.x % WARP_SIZE; + const int warp_idx = blockIdx.x * c_factor + warp; + + const int head_idx = warp_idx / d_head; + const int head_off = (warp_idx % d_head) * sizeof(float); + const int seq_idx = blockIdx.y; const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); - const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); - const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float)); - const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); - const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1); - const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); - const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); - float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH; - float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + // TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase + const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); + const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); + const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); + const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); + const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); + const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); + float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; + float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); // strides across n_seq_tokens const int stride_x = src1_nb2 / sizeof(float); @@ -147,80 +152,42 @@ __global__ void __launch_bounds__(d_state, 1) const int stride_C = src5_nb2 / sizeof(float); const int stride_y = n_head * d_head; - float state[splitH]; - // for the parallel accumulation - __shared__ float stateC[splitH * d_state]; + float state[c_factor]; + float state_sum = 0.0f; #pragma unroll - for (int j = 0; j < splitH; j++) { - state[j] = s0_block[j * d_state + threadIdx.x]; + for (int j = 0; j < c_factor; j++) { + state[j] = s0_warp[WARP_SIZE * j + lane]; } for (int64_t i = 0; i < n_tok; i++) { - // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements - // TODO: only calculate B and C once per head group - // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here. - float dt_soft_plus = dt_block[i * stride_dt]; - if (dt_soft_plus <= 20.0f) { - dt_soft_plus = log1pf(expf(dt_soft_plus)); - } - const float dA = expf(dt_soft_plus * A_block[0]); - const float B = B_block[i * stride_B + threadIdx.x]; - const float C = C_block[i * stride_C + threadIdx.x]; + // NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. + // Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. + const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); - // across d_head + state_sum = 0.0f; + const float dA = expf(dt_soft_plus * A_warp[0]); + const float x_dt = x_warp[i * stride_x] * dt_soft_plus; #pragma unroll - for (int j = 0; j < splitH; j++) { - const float x_dt = x_block[i * stride_x + j] * dt_soft_plus; - - state[j] = (state[j] * dA) + (B * x_dt); - - stateC[j * d_state + threadIdx.x] = state[j] * C; + for (int j = 0; j < c_factor; j++) { + const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; + const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane]; + state[j] = (state[j] * dA) + (B_val * x_dt); + state_sum += state[j] * C_val; } - __syncthreads(); + // parallel accumulation for output + state_sum = warp_reduce_sum(state_sum); - // parallel accumulation for stateC - // TODO: simplify - { - static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2"); - static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2"); - - // reduce until w matches the warp size - // TODO: does this work even when the physical warp size is 64? -#pragma unroll - for (int w = d_state; w > WARP_SIZE; w >>= 1) { - // (assuming there are d_state threads) -#pragma unroll - for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) { - // TODO: check for bank conflicts - const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1)); - stateC[k] += stateC[k + (w >> 1)]; - - } - __syncthreads(); - } - - static_assert(splitH >= d_state / WARP_SIZE); - -#pragma unroll - for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) { - float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)]; - y = warp_reduce_sum(y); - - // store the above accumulations - if (threadIdx.x % WARP_SIZE == 0) { - const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE); - y_block[i * stride_y + k] = y; - } - } + if (lane == 0) { + y_warp[i * stride_y] = state_sum; } } // write back the state #pragma unroll - for (int j = 0; j < splitH; j++) { - s_block[j * d_state + threadIdx.x] = state[j]; + for (int j = 0; j < c_factor; j++) { + s_warp[WARP_SIZE * j + lane] = state[j]; } } @@ -231,27 +198,24 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, cudaStream_t stream) { - const int threads = 128; // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! if (src3_nb1 == sizeof(float)) { // Mamba-2 if (d_state == 128) { - GGML_ASSERT(d_state % threads == 0); - // NOTE: can be any power of two between 4 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 128><<>>( + constexpr int threads = 128; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<128/WARP_SIZE, 128><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); } else if (d_state == 256) { // Falcon-H1 - const int threads = 256; - // NOTE: can be any power of two between 8 and 64 - const int splitH = 16; - GGML_ASSERT(head_dim % splitH == 0); - const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1); - ssm_scan_f32_group<16, 256><<>>( + constexpr int threads = 256; + constexpr int num_warps = threads/WARP_SIZE; + + const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); + ssm_scan_f32_group<256/WARP_SIZE, 256><<>>( src0, src1, src2, src3, src4, src5, src6, dst, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); @@ -260,6 +224,7 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa } } else { // Mamba-1 + constexpr int threads = 128; GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(head_dim == 1); GGML_ASSERT(n_group == 1); diff --git a/ggml/src/ggml-cuda/top-k.cu b/ggml/src/ggml-cuda/top-k.cu new file mode 100644 index 0000000000..318ac38691 --- /dev/null +++ b/ggml/src/ggml-cuda/top-k.cu @@ -0,0 +1,96 @@ +#include "argsort.cuh" +#include "top-k.cuh" + +#ifdef GGML_CUDA_USE_CUB +# include +# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2) +# include +# define CUB_TOP_K_AVAILABLE +using namespace cub; +# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2 +#endif // GGML_CUDA_USE_CUB + +#ifdef CUB_TOP_K_AVAILABLE + +static void top_k_cub(ggml_cuda_pool & pool, + const float * src, + int * dst, + const int ncols, + const int k, + cudaStream_t stream) { + auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed, + cuda::execution::output_ordering::unsorted); + auto stream_env = cuda::stream_ref{ stream }; + auto env = cuda::std::execution::env{ stream_env, requirements }; + + auto indexes_in = cuda::make_counting_iterator(0); + + size_t temp_storage_bytes = 0; + DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k, + env); + + ggml_cuda_pool_alloc temp_storage_alloc(pool, temp_storage_bytes); + void * d_temp_storage = temp_storage_alloc.get(); + + DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, + ncols, k, env); +} + +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + +static int next_power_of_2(int x) { + int n = 1; + while (n < x) { + n *= 2; + } + return n; +} + +#endif // CUB_TOP_K_AVAILABLE + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const float * src0_d = (const float *) src0->data; + int * dst_d = (int *) dst->data; + cudaStream_t stream = ctx.stream(); + + // are these asserts truly necessary? + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_I32); + GGML_ASSERT(ggml_is_contiguous(src0)); + + const int64_t ncols = src0->ne[0]; + const int64_t nrows = ggml_nrows(src0); + const int64_t k = dst->ne[0]; + ggml_cuda_pool & pool = ctx.pool(); +#ifdef CUB_TOP_K_AVAILABLE + // TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented + // https://github.com/NVIDIA/cccl/issues/6391 + // TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k + for (int i = 0; i < nrows; i++) { + top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream); + } +#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE + // Fall back to argsort + copy + const int ncols_pad = next_power_of_2(ncols); + const size_t shared_mem = ncols_pad * sizeof(int); + const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb; + + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + + if (shared_mem > max_shared_mem || ncols > 1024) { + argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } else { + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + } + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#else // GGML_CUDA_USE_CUB + ggml_cuda_pool_alloc temp_dst_alloc(pool, ncols * nrows); + int * tmp_dst = temp_dst_alloc.get(); + argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream); + CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows, + cudaMemcpyDeviceToDevice, stream)); +#endif +} diff --git a/ggml/src/ggml-cuda/top-k.cuh b/ggml/src/ggml-cuda/top-k.cuh new file mode 100644 index 0000000000..f4d8f61e5b --- /dev/null +++ b/ggml/src/ggml-cuda/top-k.cuh @@ -0,0 +1,3 @@ +#include "common.cuh" + +void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h index 951a88d567..016b04e5a0 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h @@ -45,9 +45,11 @@ #define cublasSgemm hipblasSgemm #define cublasStatus_t hipblasStatus_t #define cublasOperation_t hipblasOperation_t +#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess +#define cudaDeviceGetAttribute hipDeviceGetAttribute #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceSynchronize hipDeviceSynchronize #define cudaError_t hipError_t @@ -70,6 +72,7 @@ #define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostUnregister hipHostUnregister +#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel #define cudaLaunchHostFunc hipLaunchHostFunc #define cudaMalloc hipMalloc #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) diff --git a/ggml/src/ggml-cuda/vendors/musa.h b/ggml/src/ggml-cuda/vendors/musa.h index 221e67f96a..1abb8acfd4 100644 --- a/ggml/src/ggml-cuda/vendors/musa.h +++ b/ggml/src/ggml-cuda/vendors/musa.h @@ -61,6 +61,7 @@ #define cudaHostRegisterPortable musaHostRegisterPortable #define cudaHostRegisterReadOnly musaHostRegisterReadOnly #define cudaHostUnregister musaHostUnregister +#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel #define cudaLaunchHostFunc musaLaunchHostFunc #define cudaMalloc musaMalloc #define cudaMallocHost musaMallocHost diff --git a/ggml/src/ggml-hexagon/ggml-hexagon.cpp b/ggml/src/ggml-hexagon/ggml-hexagon.cpp index 13b96d61f8..365a24b496 100644 --- a/ggml/src/ggml-hexagon/ggml-hexagon.cpp +++ b/ggml/src/ggml-hexagon/ggml-hexagon.cpp @@ -1773,6 +1773,37 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_ return true; } +static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; + const struct ggml_tensor * src1 = op->src[1]; + const struct ggml_tensor * src2 = op->src[2]; + const struct ggml_tensor * src3 = op->src[3]; + const struct ggml_tensor * src4 = op->src[4]; + const struct ggml_tensor * dst = op; + + // Check for F16 support only as requested + if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) { + return false; + } + + if (src3 && src3->type != GGML_TYPE_F16) { // mask + return false; + } + + if (src4 && src4->type != GGML_TYPE_F32) { // sinks + return false; + } + + // For now we support F32 or F16 output as htp backend often converts output on the fly if needed, + // but the op implementation writes to F16 or F32. + // Let's assume dst can be F32 or F16. + if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) { + return false; + } + + return opt_experimental; +} + static bool hex_supported_src0_type(ggml_type t) { return t == GGML_TYPE_F32; } @@ -1815,12 +1846,11 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src1 = dst->src[1]; - if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) { + if (dst->type != GGML_TYPE_F32) { return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { + if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { return false; } @@ -1836,7 +1866,6 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s return false; // typically the lm-head which would be too large for VTCM } - // if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false; if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { return false; } @@ -1885,21 +1914,10 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session } break; - case GGML_TYPE_F16: - if (!opt_experimental) { - return false; - } - break; - default: return false; } - // TODO: add support for non-cont tensors - if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) { - return false; - } - return true; } @@ -2060,6 +2078,46 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s return true; } +static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F16) { + return false; + } + + return true; +} + +static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { + const struct ggml_tensor * src0 = op->src[0]; // values + const struct ggml_tensor * src1 = op->src[1]; // indices + const struct ggml_tensor * dst = op; + + if (src0->type != GGML_TYPE_F32) { + return false; + } + + if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) { + return false; + } + + if (dst->type != GGML_TYPE_F32) { + return false; + } + + return true; +} + static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { const int32_t * op_params = &op->op_params[0]; @@ -2154,6 +2212,11 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t d->offset = (uint8_t *) t->data - buf->base; d->size = ggml_nbytes(t); + if (!d->size) { + // Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty + d->size = 64; + } + switch (type) { case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: // Flush CPU @@ -2239,6 +2302,17 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu return n_bufs; } +static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_GET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + template static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { switch (t->op) { @@ -2266,6 +2340,17 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * return n_bufs; } +static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + req->op = HTP_OP_SET_ROWS; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); @@ -2277,6 +2362,11 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf supported = true; break; + case GGML_OP_SCALE: + req->op = HTP_OP_SCALE; + supported = true; + break; + case GGML_OP_UNARY: if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { req->op = HTP_OP_UNARY_SILU; @@ -2331,6 +2421,21 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs return n_bufs; } +static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { + memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); + req->op = HTP_OP_FLASH_ATTN_EXT; + + size_t n_bufs = 0; + n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ); + n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ); + + return n_bufs; +} + static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { auto sess = static_cast(backend->context); return sess->name.c_str(); @@ -2417,6 +2522,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op>(sess, node, flags); break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: ggml_hexagon_dispatch_op(sess, node, flags); break; case GGML_OP_UNARY: @@ -2439,6 +2545,18 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg ggml_hexagon_dispatch_op(sess, node, flags); break; + case GGML_OP_FLASH_ATTN_EXT: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_SET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + + case GGML_OP_GET_ROWS: + ggml_hexagon_dispatch_op(sess, node, flags); + break; + default: GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); } @@ -2778,6 +2896,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons break; case GGML_OP_RMS_NORM: + case GGML_OP_SCALE: supp = ggml_hexagon_supported_unary(sess, op); break; @@ -2805,6 +2924,18 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons supp = ggml_hexagon_supported_rope(sess, op); break; + case GGML_OP_FLASH_ATTN_EXT: + supp = ggml_hexagon_supported_flash_attn_ext(sess, op); + break; + + case GGML_OP_SET_ROWS: + supp = ggml_hexagon_supported_set_rows(sess, op); + break; + + case GGML_OP_GET_ROWS: + supp = ggml_hexagon_supported_get_rows(sess, op); + break; + default: break; } diff --git a/ggml/src/ggml-hexagon/htp/CMakeLists.txt b/ggml/src/ggml-hexagon/htp/CMakeLists.txt index 2cf8aaa42a..6a34a215fa 100644 --- a/ggml/src/ggml-hexagon/htp/CMakeLists.txt +++ b/ggml/src/ggml-hexagon/htp/CMakeLists.txt @@ -28,6 +28,9 @@ add_library(${HTP_LIB} SHARED softmax-ops.c act-ops.c rope-ops.c + flash-attn-ops.c + set-rows-ops.c + get-rows-ops.c ) target_compile_definitions(${HTP_LIB} PRIVATE diff --git a/ggml/src/ggml-hexagon/htp/act-ops.c b/ggml/src/ggml-hexagon/htp/act-ops.c index 7e488456ee..88bd2ddc43 100644 --- a/ggml/src/ggml-hexagon/htp/act-ops.c +++ b/ggml/src/ggml-hexagon/htp/act-ops.c @@ -85,13 +85,16 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, struct htp_spad * dst_spad, uint32_t nth, uint32_t ith, - uint32_t src0_nrows_per_thread) { + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { htp_act_preamble3; size_t src0_row_size = nb01; size_t src1_row_size = nb11; size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src0_start_row = src0_nrows_per_thread * ith; @@ -105,10 +108,129 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); - int is_aligned = 1; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { - is_aligned = 0; - FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); + const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; + const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; + uint8_t * restrict data_dst = (uint8_t *) dst->data; + + const bool src1_valid = src1->ne[0]; + const int nc = (src1_valid) ? ne00 : ne00 / 2; + if (!src1_valid) { + const int32_t swapped = op_params[1]; + data_src1 = data_src0; + src1_row_size = src0_row_size; + + const size_t nc_in_bytes = nc * SIZEOF_FP32; + data_src0 += swapped ? nc_in_bytes : 0; + data_src1 += swapped ? 0 : nc_in_bytes; + } + + const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); + + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; + } + + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); + } + + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float)); + float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + //swiglu(x) = x1 * sigmoid(x0) + hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + } + + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); + + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } + } + + dma_queue_flush(dma_queue); + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, + const struct htp_tensor * src1, + struct htp_tensor * dst, + const int32_t * op_params, + struct htp_spad * src0_spad, + struct htp_spad * src1_spad, + struct htp_spad * dst_spad, + uint32_t nth, + uint32_t ith, + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { + htp_act_preamble3; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + size_t src0_row_size = nb01; + size_t src1_row_size = nb11; + size_t dst_row_size = nb1; + + const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows + + const uint32_t src0_start_row = src0_nrows_per_thread * ith; + const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); + + // no work for this thread + if (src0_start_row >= src0_end_row) { + return; } const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; @@ -127,130 +249,94 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0, data_src1 += swapped ? 0 : nc_in_bytes; } - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); + const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); + const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); - const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1))); - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); - const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); - float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); + uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); - if (ir + 1 < src0_end_row) { - htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); - } + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t src1_spad_half_size = src1_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; - if (opt_path) { - hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc); - hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1, - (uint8_t *) dst, nc); - } else { - hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true); - hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc); - hvx_inverse_f32(src1_spad_data, src0_spad_data, nc); - - hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc); - hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc); - } - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, - ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - -static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0, - const struct htp_tensor * src1, - struct htp_tensor * dst, - const int32_t * op_params, - struct htp_spad * src0_spad, - struct htp_spad * src1_spad, - struct htp_spad * dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread) { - htp_act_preamble3; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - const size_t src0_row_size = nb01; - const size_t src1_row_size = nb11; - const size_t dst_row_size = nb1; - - const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows - - const uint32_t src0_start_row = src0_nrows_per_thread * ith; - const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows); - - // no work for this thread - if (src0_start_row >= src0_end_row) { + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + if (BLOCK == 0) { + FARF(ERROR, + "swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " + "%zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); return; } + const float alpha = ((const float *) (op_params))[2]; + const float limit = ((const float *) (op_params))[3]; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { - FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n"); + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); + + dma_queue_push_ddr_to_vtcm( + dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + dma_queue_push_ddr_to_vtcm( + dma_queue, + dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)), + src1_row_size_aligned, src1_row_size, block_size); } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - bool src1_valid = src1->ne[0]; - if (!src1_valid) { - data_src1 = data_src0; - } + float * dst_spad = (float *) dma_queue_pop(dma_queue).src; + float * src0_spad = (float *) dma_queue_pop(dma_queue).dst; + float * src1_spad = (float *) dma_queue_pop(dma_queue).dst; - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); - uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + for (uint32_t ib = 0; ib < block_size; ib++) { + const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float)); + float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); - const int32_t swapped = op_params[1]; - const float alpha = ((const float *) (op_params))[2]; - const float limit = ((const float *) (op_params))[3]; - - const int nc = (src1_valid) ? ne00 : ne00 / 2; - - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); - const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size)); - float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); - - if (ir + 1 < src0_end_row) { - htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + // x (src0_spad_data) = std::min(src0_p[k], limit); + hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc); + // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); + hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc); + // y (src1_spad_data) = y1 + 1.f + hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc); + // x1 (dst_spad_data) = alpha * (x) + hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc); + // x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1)) + hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc); + // out = x * sigmoid(alpha * x) * (y + 1.f) + hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, + (const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc); } - if (!src1) { - src0 += swapped ? nc : 0; - src1 += swapped ? 0 : nc; - } + dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, + dst_row_size_aligned, block_size); - // x (src0_spad_data) = std::min(src0_p[k], limit); - hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc); - // y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit); - hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc); - // y (src1_spad_data) = y1 + 1.f - hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc); - // x1 (dst_spad_data) = alpha * (x) - hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc); - // x2 (dst_spad_data) = expf(-x1) - hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true); - // x3 (dst_spad_data) = x2 + 1.f - hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc); - // x4 (dst_spad_data) = 1 / x3 - hvx_inverse_f32(dst_spad_data, dst_spad_data, nc); - // out_glu(dst_spad_data) = x * x4 - hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc); - // out = out_glu * (y + 1.f); - hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc); + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)), + src1_row_size_aligned, src1_row_size, pref_block_size); + } } + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], + FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } @@ -371,7 +457,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, struct htp_spad * dst_spad, uint32_t nth, uint32_t ith, - uint32_t src0_nrows_per_thread) { + uint32_t src0_nrows_per_thread, + dma_queue * dma_queue) { htp_act_preamble2; uint64_t t1, t2; @@ -379,6 +466,8 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, const size_t src0_row_size = nb01; const size_t dst_row_size = nb1; + const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); + const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); const uint32_t src0_nrows = ne01 * ne02 * ne03; @@ -390,64 +479,91 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0, return; } - int is_aligned = 1; - int opt_path = 0; - if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) { - is_aligned = 0; - FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n"); - } - if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) { - opt_path = 1; + const uint8_t * data_src0 = (const uint8_t *) src0->data; + uint8_t * data_dst = (uint8_t *) dst->data; + + uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); + uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); + + // While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 + size_t src0_spad_half_size = src0_spad->size_per_thread / 2; + size_t dst_spad_half_size = dst_spad->size_per_thread / 2; + + const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block + + if (BLOCK == 0) { + FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n", + src0_spad->size_per_thread, src0_row_size_aligned); + return; } - const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; - uint8_t * restrict data_dst = (uint8_t *) dst->data; + // See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 + for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); - uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size); - uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size); + // Dummy DMA transation for sequencing (interleaving dst,src,dst,...) + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)), + dst_row_size, dst_row_size_aligned, 0); - for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) { - const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size)); - float * restrict dst = (float *) (data_dst + (ir * dst_row_size)); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), + src0_row_size_aligned, src0_row_size, block_size); + } - if (ir + 1 < src0_end_row) { - htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size); + for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { + const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); + + float* dst_spad = (float *) dma_queue_pop(dma_queue).src; + float* src0_spad = (float *) dma_queue_pop(dma_queue).dst; + + for (uint32_t ib = 0; ib < block_size; ib++) { + const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float)); + float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float)); + + // silu = x * sigmoid(x) + hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); + hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0); } - if (1 == opt_path) { - hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0); - hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); - } else { - hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true); - hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0); - hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0); + dma_queue_push_vtcm_to_ddr(dma_queue, + dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), + dst_row_size, dst_row_size_aligned, block_size); - hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0); + // prefetch N+2 loop iteration if any + const uint32_t pref_block = (ir + BLOCK * 2); + if (pref_block < src0_end_row) { + const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block); + dma_queue_push_ddr_to_vtcm(dma_queue, + dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)), + src0_row_size_aligned, src0_row_size, pref_block_size); } } + dma_queue_flush(dma_queue); + t2 = HAP_perf_get_qtimer_count(); - FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02, + FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread); + octx->src0_nrows_per_thread, octx->ctx->dma[i]); } static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = (struct htp_ops_context *) data; glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, - &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread); + &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); } static int execute_op_activations_fp32(struct htp_ops_context * octx) { diff --git a/ggml/src/ggml-hexagon/htp/flash-attn-ops.c b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c new file mode 100644 index 0000000000..04a7b843ce --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/flash-attn-ops.c @@ -0,0 +1,566 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-dma.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +// Dot product of FP32 and FP16 vectors, accumulating to float +static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + + hvx_vec_store_u(r, 4, rsum); +} + +// Dot product of two F16 vectors, accumulating to float +static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) { + const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16 + const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16 + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = vy[i]; + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_Vector y_hf = vy[i]; + + // Load x (fp16) and zero-out unused elements + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s)); + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(r, 4, rsum); +} + +// MAD: y (F32) += x (F16) * v (float) +static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) { + const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x; + HVX_Vector * restrict ptr_y = (HVX_Vector *) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector S = hvx_vec_splat_fp16(s); + + uint32_t i = 0; + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + // Multiply x * s -> pair of F32 vectors + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2])); + ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1])); + } + + if (nloe) { + HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S); + + HVX_Vector xs = Q6_V_lo_W(xs_p); + i = 2 * i; // index for ptr_y + + if (nloe >= 32) { + ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p); + } + + if (nloe) { + HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i])); + hvx_vec_store_u(&ptr_y[i], nloe * 4, xy); + } + } +} + +#define FLASH_ATTN_BLOCK_SIZE 128 + +static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL; + const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL; + struct htp_tensor * dst = &octx->dst; + + const uint32_t neq0 = q->ne[0]; + const uint32_t neq1 = q->ne[1]; + const uint32_t neq2 = q->ne[2]; + const uint32_t neq3 = q->ne[3]; + + const uint32_t nek0 = k->ne[0]; + const uint32_t nek1 = k->ne[1]; + const uint32_t nek2 = k->ne[2]; + const uint32_t nek3 = k->ne[3]; + + const uint32_t nev0 = v->ne[0]; + const uint32_t nev1 = v->ne[1]; + const uint32_t nev2 = v->ne[2]; + const uint32_t nev3 = v->ne[3]; + + const uint32_t nbq1 = q->nb[1]; + const uint32_t nbq2 = q->nb[2]; + const uint32_t nbq3 = q->nb[3]; + + const uint32_t nbk1 = k->nb[1]; + const uint32_t nbk2 = k->nb[2]; + const uint32_t nbk3 = k->nb[3]; + + const uint32_t nbv1 = v->nb[1]; + const uint32_t nbv2 = v->nb[2]; + const uint32_t nbv3 = v->nb[3]; + + const uint32_t ne1 = dst->ne[1]; + const uint32_t ne2 = dst->ne[2]; + const uint32_t ne3 = dst->ne[3]; + + const uint32_t nb1 = dst->nb[1]; + const uint32_t nb2 = dst->nb[2]; + const uint32_t nb3 = dst->nb[3]; + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (float *) octx->op_params + 0, sizeof(float)); + memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + // total rows in q + const uint32_t nr = neq1*neq2*neq3; + + const uint32_t dr = (nr + nth - 1) / nth; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = MIN(ir0 + dr, nr); + + if (ir0 >= ir1) return; + + dma_queue * dma = octx->ctx->dma[ith]; + + const uint32_t DK = nek0; + const uint32_t DV = nev0; + + const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2); + const size_t size_q_row_padded = htp_round_up(size_q_row, 128); + + const size_t size_k_row = DK * sizeof(__fp16); + const size_t size_v_row = DV * sizeof(__fp16); + const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask + + const size_t size_k_row_padded = htp_round_up(size_k_row, 128); + const size_t size_v_row_padded = htp_round_up(size_v_row, 128); + + const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + // Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator + uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith; + uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith; + uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith; + uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith; + uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith; + + const uint32_t n_head = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + for (uint32_t ir = ir0; ir < ir1; ++ir) { + const uint32_t iq3 = fastdiv(ir, &octx->src0_div21); + const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1); + const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1); + + const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3); + const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2); + + const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3); + const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2); + + // Fetch Q row + const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3); + dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1); + + const uint32_t h = iq2; // head index + const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f; + + float S = 0.0f; // sum + float M = -INFINITY; // maximum KQ value + + // Clear accumulator + float * VKQ32 = (float *) spad_a; + memset(VKQ32, 0, DV * sizeof(float)); + + const __fp16 * mp_base = NULL; + if (mask) { + const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2); + const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3); + mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]); + } + + const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE; + + // Prefetch first two blocks + for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + uint8_t * k_dst = spad_k + (ib % 2) * size_k_block; + dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + uint8_t * v_dst = spad_v + (ib % 2) * size_v_block; + dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start); + uint8_t * m_dst = spad_m + (ib % 2) * size_m_block; + // Mask is 1D contiguous for this row + dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1); + } + } + + const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst; + + for (uint32_t ib = 0; ib < n_blocks; ++ib) { + const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start); + + // Wait for DMA + uint8_t * k_base = dma_queue_pop(dma).dst; // K + uint8_t * v_base = dma_queue_pop(dma).dst; // V + __fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M + + // Inner loop processing the block from VTCM + uint32_t ic = 0; + + // Process in blocks of 32 (VLEN_FP32) + for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) { + // 1. Compute scores + float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32]; + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); + } + } + + HVX_Vector scores = *(HVX_Vector *) scores_arr; + + // 2. Softcap + if (logit_softcap != 0.0f) { + scores = hvx_vec_tanh_fp32(scores); + scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 3. Mask + if (mask) { + const __fp16 * mp = m_base + ic; + HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp; + + HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00); + HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16); + + HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair)); + + HVX_Vector slope_vec = hvx_vec_splat_fp32(slope); + HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec); + scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val)); + scores = Q6_Vsf_equals_Vqf32(scores); + } + + // 4. Online Softmax Update + HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores); + float m_block = hvx_vec_get_fp32(v_max); + + float M_old = M; + float M_new = (m_block > M) ? m_block : M; + M = M_new; + + float ms = expf(M_old - M_new); + + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + S = S * ms; + + HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new); + HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec); + HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted)); + + HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P); + float p_sum = hvx_vec_get_fp32(p_sum_vec); + S += p_sum; + + // 5. Accumulate V + float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32]; + *(HVX_Vector*)p_arr = P; + + for (int j = 0; j < VLEN_FP32; ++j) { + const uint32_t cur_ic = ic + j; + const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded; + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]); + } + } + + // Leftover + for (; ic < current_block_size; ++ic) { + float s_val; + const uint8_t * k_ptr = k_base + ic * size_k_row_padded; + + if (q->type == HTP_TYPE_F32) { + hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } else { + hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); + } + + if (logit_softcap != 0.0f) { + s_val = logit_softcap * tanhf(s_val); + } + + if (mask) { + const float m_val = m_base[ic]; + s_val += slope * m_val; + } + + const float Mold = M; + float ms = 1.0f; + float vs = 1.0f; + + if (s_val > M) { + M = s_val; + ms = expf(Mold - M); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s_val - M); + } + + const uint8_t * v_ptr = v_base + ic * size_v_row_padded; + + hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs); + + S = S * ms + vs; + } + + // Issue DMA for next+1 block (if exists) + if (ib + 2 < n_blocks) { + const uint32_t next_ib = ib + 2; + const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE; + const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start); + + // K + const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3); + dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size); + + // V + const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3); + dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size); + + // Mask + if (mask) { + const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start); + dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1); + } + } + } + + // sinks + if (sinks) { + const float s = ((float *)((char *) sinks->data))[h]; + + float ms = 1.0f; + float vs = 1.0f; + + if (s > M) { + ms = expf(M - s); + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms); + } else { + vs = expf(s - M); + } + + S = S * ms + vs; + } + + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; + hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv); + + // Store result + // dst indices + const int i1 = iq1; + const int i2 = iq2; + const int i3 = iq3; + + // dst is permuted + uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1; + + if (dst->type == HTP_TYPE_F32) { + hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } else if (dst->type == HTP_TYPE_F16) { + hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV); + } + } +} + +static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + flash_attn_ext_f16_thread(octx, i, n); +} + +int op_flash_attn_ext(struct htp_ops_context * octx) { + const struct htp_tensor * q = &octx->src0; + const struct htp_tensor * k = &octx->src1; + const struct htp_tensor * v = &octx->src2; + const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL; + struct htp_tensor * dst = &octx->dst; + + // Check support + if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) || + k->type != HTP_TYPE_F16 || + v->type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]); + octx->src0_div1 = init_fastdiv_values(q->ne[1]); + + octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]); + octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]); + octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]); + octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]); + + if (mask) { + octx->src3_div2 = init_fastdiv_values(mask->ne[2]); + octx->src3_div3 = init_fastdiv_values(mask->ne[3]); + } + + size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128); + size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128); + size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128); + + size_t size_q_block = size_q_row_padded * 1; // single row for now + size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE; + size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128); + + size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32 + + octx->src0_spad.size_per_thread = size_q_block * 1; + octx->src1_spad.size_per_thread = size_k_block * 2; + octx->src2_spad.size_per_thread = size_v_block * 2; + octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0; + octx->dst_spad.size_per_thread = size_vkq_acc; + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads; + octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size; + + if (octx->ctx->vtcm_size < total_spad) { + return HTP_STATUS_VTCM_TOO_SMALL; + } + + octx->src0_spad.data = octx->ctx->vtcm_base; + octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size; + octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size; + octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size; + octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size; + + if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) { + worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads); + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/get-rows-ops.c b/ggml/src/ggml-hexagon/htp/get-rows-ops.c new file mode 100644 index 0000000000..54321421eb --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/get-rows-ops.c @@ -0,0 +1,112 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define get_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t nr = ne10 * ne11 * ne12; + +static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + get_rows_preamble; + + // parallelize by src1 elements (which correspond to dst rows) + const uint32_t dr = octx->src1_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11); + const uint32_t rem = i - i12 * ne11 * ne10; + const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10); + const uint32_t i10 = rem - i11 * ne10; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + + if (i01 >= ne01) { + // invalid index, skip for now to avoid crash + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03; + const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3; + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + + return HTP_STATUS_OK; +} + +static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_get_rows(struct htp_ops_context * octx) { + get_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]); + octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs); + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/htp-ctx.h b/ggml/src/ggml-hexagon/htp/htp-ctx.h index 5c3d217f1c..4bd0ea7a36 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ctx.h +++ b/ggml/src/ggml-hexagon/htp/htp-ctx.h @@ -11,11 +11,6 @@ #define HTP_MAX_NTHREADS 10 -// FIXME: move these into matmul-ops -#define HTP_SPAD_SRC0_NROWS 16 -#define HTP_SPAD_SRC1_NROWS 16 -#define HTP_SPAD_DST_NROWS 2 - // Main context for htp DSP backend struct htp_context { dspqueue_t queue; diff --git a/ggml/src/ggml-hexagon/htp/htp-msg.h b/ggml/src/ggml-hexagon/htp/htp-msg.h index a61652304a..846d061784 100644 --- a/ggml/src/ggml-hexagon/htp/htp-msg.h +++ b/ggml/src/ggml-hexagon/htp/htp-msg.h @@ -36,6 +36,8 @@ enum htp_data_type { HTP_TYPE_F16 = 1, HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q8_0 = 8, + HTP_TYPE_I32 = 26, + HTP_TYPE_I64 = 27, HTP_TYPE_MXFP4 = 39, HTP_TYPE_COUNT }; @@ -57,6 +59,10 @@ enum htp_op { HTP_OP_SOFTMAX = 11, HTP_OP_ADD_ID = 12, HTP_OP_ROPE = 13, + HTP_OP_FLASH_ATTN_EXT = 14, + HTP_OP_SET_ROWS = 15, + HTP_OP_SCALE = 16, + HTP_OP_GET_ROWS = 17, INVALID }; @@ -137,6 +143,8 @@ struct htp_general_req { struct htp_tensor src0; // Input0 tensor struct htp_tensor src1; // Input1 tensor struct htp_tensor src2; // Input2 tensor + struct htp_tensor src3; // Input3 tensor + struct htp_tensor src4; // Input4 tensor struct htp_tensor dst; // Output tensor // should be multiple of 64 bytes (cacheline) @@ -152,6 +160,6 @@ struct htp_general_rsp { }; #define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) -#define HTP_MAX_PACKET_BUFFERS 4 +#define HTP_MAX_PACKET_BUFFERS 8 #endif /* HTP_MSG_H */ diff --git a/ggml/src/ggml-hexagon/htp/htp-ops.h b/ggml/src/ggml-hexagon/htp/htp-ops.h index e87657436f..7c828ae636 100644 --- a/ggml/src/ggml-hexagon/htp/htp-ops.h +++ b/ggml/src/ggml-hexagon/htp/htp-ops.h @@ -13,6 +13,7 @@ struct htp_spad { uint8_t * data; + size_t stride; size_t size; size_t size_per_thread; }; @@ -26,11 +27,14 @@ struct htp_ops_context { struct htp_tensor src0; struct htp_tensor src1; struct htp_tensor src2; + struct htp_tensor src3; + struct htp_tensor src4; struct htp_tensor dst; struct htp_spad src0_spad; struct htp_spad src1_spad; struct htp_spad src2_spad; + struct htp_spad src3_spad; struct htp_spad dst_spad; worker_pool_context_t * wpool; // worker pool @@ -49,6 +53,27 @@ struct htp_ops_context { struct fastdiv_values src1_div3; // fastdiv values for ne3 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 + struct fastdiv_values src3_div1; // fastdiv values for ne1 + struct fastdiv_values src3_div2; // fastdiv values for ne2 + struct fastdiv_values src3_div3; // fastdiv values for ne3 + struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1 + + struct fastdiv_values broadcast_rk2; + struct fastdiv_values broadcast_rk3; + struct fastdiv_values broadcast_rv2; + struct fastdiv_values broadcast_rv3; + + struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1 + struct fastdiv_values mm_div_ne1; // fastdiv values for ne1 + struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02 + struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03 + + struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12 + struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11 + + struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10 + struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11 + uint32_t flags; }; @@ -60,5 +85,8 @@ int op_activations(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx); int op_rope(struct htp_ops_context * octx); +int op_flash_attn_ext(struct htp_ops_context * octx); +int op_set_rows(struct htp_ops_context * octx); +int op_get_rows(struct htp_ops_context * octx); #endif /* HTP_OPS_H */ diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.c b/ggml/src/ggml-hexagon/htp/hvx-utils.c index f9e02ab67e..29d73b8622 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.c +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.c @@ -848,55 +848,6 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) { return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); } -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) { - int left_over = num_elems & (VLEN_FP32 - 1); - int num_elems_whole = num_elems - left_over; - - int unaligned_addr = 0; - int unaligned_loop = 0; - if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) { - FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n"); - unaligned_addr = 1; - } - - if ((1 == unaligned_addr) && (num_elems_whole != 0)) { - unaligned_loop = 1; - FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n"); - } - - HVX_Vector scale_vec = hvx_vec_splat_fp32(scale); - - if (0 == unaligned_loop) { - HVX_Vector * vec_in1 = (HVX_Vector *) src; - HVX_Vector * vec_out = (HVX_Vector *) dst; - - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec); - *vec_out++ = Q6_Vsf_equals_Vqf32(v); - } - } else { - #pragma unroll(4) - for (int i = 0; i < num_elems_whole; i += VLEN_FP32) { - HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32); - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - - *(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out); - } - } - - if (left_over > 0) { - const float * srcf = (const float *) src + num_elems_whole; - float * dstf = (float *) dst + num_elems_whole; - - HVX_Vector in = *(HVX_UVector *) srcf; - - HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec); - hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out)); - } -} - float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { int left_over = num_elems & (VLEN_FP32 - 1); int num_elems_whole = num_elems - left_over; @@ -1065,3 +1016,5 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src, hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); } } + + diff --git a/ggml/src/ggml-hexagon/htp/hvx-utils.h b/ggml/src/ggml-hexagon/htp/hvx-utils.h index d2d5d23636..22876e6dba 100644 --- a/ggml/src/ggml-hexagon/htp/hvx-utils.h +++ b/ggml/src/ggml-hexagon/htp/hvx-utils.h @@ -41,15 +41,24 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in) } #endif -static inline HVX_Vector hvx_vec_splat_fp32(float i) { +static inline HVX_Vector hvx_vec_splat_fp32(float v) { union { - float f; - int32_t i; - } fp32 = { .f = i }; + float f; + uint32_t i; + } fp32 = { .f = v }; return Q6_V_vsplat_R(fp32.i); } +static inline HVX_Vector hvx_vec_splat_fp16(float v) { + union { + __fp16 f; + uint16_t i; + } fp16 = { .f = v }; + + return Q6_Vh_vsplat_R(fp16.i); +} + static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { // Rotate as needed. v = Q6_V_vlalign_VVR(v, v, (size_t) addr); @@ -242,6 +251,120 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest } } +// copy n fp32 elements : source is unaligned, destination unaligned +static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; + HVX_UVector * restrict vsrc = (HVX_UVector *) src; + + assert((unsigned long) dst % 128 == 0); + + uint32_t nvec = n / 32; + uint32_t nloe = n % 32; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + HVX_Vector v = vsrc[i]; + vdst[i] = v; + } + + if (nloe) { + HVX_Vector v = vsrc[i]; + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned +static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16 + HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + +// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned +static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) { + HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16 + HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32 + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + uint32_t nvec = n / 64; + uint32_t nloe = n % 64; + + uint32_t i = 0; + + #pragma unroll(4) + for (; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + vdst[i] = Q6_Vh_vdeal_Vh(s_hf); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements + HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements + HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf)); + hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf)); + } +} + // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { HVX_Vector * restrict vdst = (HVX_Vector *) dst; @@ -273,8 +396,6 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3 return right_off <= chunk_size; } - - static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { HVX_VectorAlias u = { .v = v }; @@ -531,13 +652,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) { } static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { -#if __HTP_ARCH__ > 75 +#if __HVX_ARCH__ > 75 return Q6_Vsf_vfneg_Vsf(v); #else // neg by setting the fp32 sign bit HVX_Vector mask = Q6_V_vsplat_R(0x80000000); return Q6_V_vxor_VV(v, mask); -#endif // __HTP_ARCH__ > 75 +#endif // __HVX_ARCH__ > 75 } // ==================================================== @@ -976,6 +1097,24 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v, return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); } +static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) { + // tanh(x) = 2 * sigmoid(2x) - 1 + HVX_Vector two = hvx_vec_splat_fp32(2.0f); + HVX_Vector one = hvx_vec_splat_fp32(1.0f); + HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two); + + static const float kMinExp = -87.f; // 0 + static const float kMaxExp = 87.f; // 1 + HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp); + HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp); + + HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp); + + HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two); + res = Q6_Vqf32_vsub_Vqf32Vsf(res, one); + return Q6_Vsf_equals_Vqf32(res); +} + static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { int step_of_1 = num_elems >> 5; int remaining = num_elems - step_of_1 * VLEN_FP32; @@ -1056,6 +1195,115 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr } } +static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_f32_aa(dst, src, n, scale); + } else { + hvx_scale_f32_uu(dst, src, n, scale); + } +} + +static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_Vector * vsrc = (HVX_Vector *) src; + HVX_Vector * vdst = (HVX_Vector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + int nvec = n / VLEN_FP32; + int nloe = n % VLEN_FP32; + + HVX_Vector vs = hvx_vec_splat_fp32(scale); + HVX_Vector vo = hvx_vec_splat_fp32(offset); + + HVX_UVector * vsrc = (HVX_UVector *) src; + HVX_UVector * vdst = (HVX_UVector *) dst; + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; ++i) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + vdst[i] = Q6_Vsf_equals_Vqf32(v); + } + + if (nloe) { + HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo); + hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v)); + } +} + +static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) { + if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) { + hvx_scale_offset_f32_aa(dst, src, n, scale, offset); + } else { + hvx_scale_offset_f32_uu(dst, src, n, scale, offset); + } +} float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); void hvx_mul_f32(const uint8_t * restrict src0, @@ -1090,7 +1338,6 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0, uint8_t * restrict dst, const int num_elems); void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); -void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale); void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); diff --git a/ggml/src/ggml-hexagon/htp/main.c b/ggml/src/ggml-hexagon/htp/main.c index fb5508a560..24b3e90e4b 100644 --- a/ggml/src/ggml-hexagon/htp/main.c +++ b/ggml/src/ggml-hexagon/htp/main.c @@ -443,6 +443,45 @@ static void proc_matmul_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_get_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + static void proc_matmul_id_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs, @@ -668,7 +707,7 @@ static void proc_rope_req(struct htp_context * ctx, uint32_t n_bufs) { struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; - int write_idx = (n_bufs == 4) ? 3 : 2; + int write_idx = n_bufs - 1; // We had written to the output buffer, we'd also need to flush it rsp_bufs[0].fd = bufs[write_idx].fd; @@ -716,6 +755,102 @@ static void proc_rope_req(struct htp_context * ctx, send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); } +static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) { + struct dspqueue_buffer rsp_bufs[1]; + + // We had written to the output buffer, we'd also need to flush it + rsp_bufs[0].fd = bufs[2].fd; + rsp_bufs[0].ptr = bufs[2].ptr; + rsp_bufs[0].offset = bufs[2].offset; + rsp_bufs[0].size = bufs[2].size; + rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + // Setup Op context + struct htp_ops_context octx = { 0 }; + octx.ctx = ctx; + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.dst.data = (uint32_t) bufs[2].ptr; + octx.n_threads = ctx->n_threads; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_set_rows(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); +} + +static void proc_flash_attn_ext_req(struct htp_context * ctx, + struct htp_general_req * req, + struct dspqueue_buffer * bufs, + uint32_t n_bufs) { + // Setup Op context + struct htp_ops_context octx; + memset(&octx, 0, sizeof(octx)); + + octx.ctx = ctx; + octx.n_threads = ctx->n_threads; + + octx.src0 = req->src0; + octx.src1 = req->src1; + octx.src2 = req->src2; + octx.src3 = req->src3; + octx.src4 = req->src4; + octx.dst = req->dst; + octx.flags = req->flags; + octx.op = req->op; + + memcpy(octx.op_params, req->op_params, sizeof(octx.op_params)); + + // Update data pointers + octx.src0.data = (uint32_t) bufs[0].ptr; + octx.src1.data = (uint32_t) bufs[1].ptr; + octx.src2.data = (uint32_t) bufs[2].ptr; + + int last_buf = 3; + + if (octx.src3.ne[0]) { + octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid + } + + if (octx.src4.ne[0]) { + octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid + } + + octx.dst.data = (uint32_t) bufs[last_buf].ptr; + + struct profile_data prof; + profile_start(&prof); + + uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR; + if (vtcm_acquire(ctx) == AEE_SUCCESS) { + rsp_status = op_flash_attn_ext(&octx); + vtcm_release(ctx); + } + + profile_stop(&prof); + + struct dspqueue_buffer rsp_buf = bufs[last_buf]; + rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP + DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU + + send_htp_rsp(ctx, req->op, rsp_status, &bufs[last_buf], 1, &prof); +} + static void htp_packet_callback(dspqueue_t queue, int error, void * context) { struct htp_context * ctx = (struct htp_context *) context; @@ -790,6 +925,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { break; case HTP_OP_RMS_NORM: + case HTP_OP_SCALE: if (n_bufs != 2) { FARF(ERROR, "Bad unary-req buffer list"); continue; @@ -833,6 +969,30 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) { proc_rope_req(ctx, &req, bufs, n_bufs); break; + case HTP_OP_FLASH_ATTN_EXT: + if (!(n_bufs >= 4 && n_bufs <= 6)) { + FARF(ERROR, "Bad flash-attn-ext-req buffer list"); + continue; + } + proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs); + break; + + case HTP_OP_SET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad set-rows-req buffer list"); + continue; + } + proc_set_rows_req(ctx, &req, bufs); + break; + + case HTP_OP_GET_ROWS: + if (n_bufs != 3) { + FARF(ERROR, "Bad get-rows-req buffer list"); + continue; + } + proc_get_rows_req(ctx, &req, bufs); + break; + default: FARF(ERROR, "Unknown Op %u", req.op); break; diff --git a/ggml/src/ggml-hexagon/htp/matmul-ops.c b/ggml/src/ggml-hexagon/htp/matmul-ops.c index f14523d485..9bb39db9fc 100644 --- a/ggml/src/ggml-hexagon/htp/matmul-ops.c +++ b/ggml/src/ggml-hexagon/htp/matmul-ops.c @@ -26,14 +26,14 @@ #include "hvx-utils.h" #include "ops-utils.h" +#define MM_SPAD_SRC0_NROWS 16 +#define MM_SPAD_SRC1_NROWS 16 +#define MM_SPAD_DST_NROWS 2 + struct htp_matmul_type { const char * type; void (*vec_dot)(const int n, float * restrict s, const void * restrict vx, const void * restrict vy); - void (*vec_dot_rx2)(const int n, - float * restrict s, - const void * restrict vx, - uint32_t vx_row_size, - const void * restrict vy); + void (*vec_dot_rx2)(const int n, float * restrict s, const void * restrict vx, uint32_t vx_row_size, const void * restrict vy); }; typedef struct { @@ -907,145 +907,174 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n, hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); } -#if 1 -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - if (0) { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; +static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_Vector * restrict x = (const HVX_Vector *) vx; + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - for (uint32_t i = 0; i < n; i++) { - rsum += (float)vx[i] * vy[i]; - } - *s = rsum; - return; - } + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements - const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; - const HVX_UVectorPair * restrict vy = (const HVX_UVectorPair * restrict) y; + HVX_Vector rsum = Q6_V_vsplat_R(0); - uint32_t nv0 = n / 64; // num full fp16 hvx vectors - uint32_t nv1 = n % 64; // leftover elements - - // for some reason we need volatile here so that the compiler doesn't try anything funky - volatile HVX_Vector rsum = Q6_V_vsplat_R(0); - float r_sum_scalar = 0.0f; uint32_t i = 0; - for (i = 0; i < nv0; i++) { - HVX_VectorPair yp = vy[i]; - - HVX_Vector x = vx[i]; - HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - //NOTE: need volatile here to prevent compiler optimization - // Seem compiler cannot guarantee read-after-write?? - volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - - HVX_Vector sum = Q6_Vqf32_vadd_Vqf32Vqf32(hi, lo); - rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - if (nv1) { - // HVX_VectorPair yp = vy[i]; + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); - // HVX_Vector x = vx[i]; - // HVX_VectorPair xp = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(x), Q6_Vh_vsplat_R(0x3C00)); // mul by 1.0 - - // if (nv1 >= 32) { - // volatile HVX_Vector hi = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_hi_W(xp)), Q6_V_hi_W(yp)); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, hi); - // nv1 -= 32; - // } - - // rsum = hvx_vec_qf32_reduce_sum(rsum); - - // if (nv1) { - // volatile HVX_Vector lo = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(Q6_V_lo_W(xp)), Q6_V_lo_W(yp)); - // HVX_Vector sum = hvx_vec_qf32_reduce_sum_n(lo, nv1); - // rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, sum); - // } - - //process the remainder using scalar loop - rsum = hvx_vec_qf32_reduce_sum(rsum); - const __fp16 * restrict sx = (const __fp16 * restrict) x; - const float * restrict sy = (const float * restrict) y; - - for (uint32_t i = nv0 * 64; i < n; i++) { - r_sum_scalar += (float) sx[i] * sy[i]; - } - - // hvx_vec_dump_fp16("X", x); - // hvx_vec_dump_fp16("Y", y); - // hvx_vec_dump_fp32("SUM", Q6_Vsf_equals_Vqf32(sum)); - // hvx_vec_dump_fp32("RSUM", Q6_Vsf_equals_Vqf32(rsum)); - } else { - rsum = hvx_vec_qf32_reduce_sum(rsum); + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); } - *s = hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(rsum)) + r_sum_scalar; - -# ifdef HTP_DEBUG - { - float rsum = 0; - const __fp16 * restrict vx = (const __fp16 * restrict) x; - const float * restrict vy = (const float * restrict) y; - - for (uint32_t i = 0; i < n; i++) { - rsum += vx[i] * vy[i]; - } - - float diff = fabs(*s - rsum); - if (diff > 0.001) { - FARF(HIGH, "vec-dot-f16-missmatch: %u (%u:%u) expected %.6f got %.6f\n", n, nv0, nv1, rsum, *s); - // htp_dump_f16("x", vx, n); - // htp_dump_f32("y", vy, n); - } - } -# endif + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); } -#else -static void vec_dot_f16_f32(const int n, float * restrict s, const void * restrict x, const void * restrict y) { - const uint32_t fk = 64; - const uint32_t nb = n / fk; - assert(n % fk == 0); - assert(nb % 4 == 0); +static void vec_dot_f16_f16_aa_rx2(const int n, + float * restrict s, + const void * restrict vx, + uint32_t vx_row_size, + const void * restrict vy) { + const HVX_Vector * restrict x0 = (const HVX_Vector *) vx; + const HVX_Vector * restrict x1 = (const HVX_Vector *) ((const uint8_t *) vx + vx_row_size); + const HVX_Vector * restrict y = (const HVX_Vector *) vy; - const uint32_t x_blk_size = 2 * fk; // fp16 - const uint32_t y_blk_size = 4 * fk; // fp32 + uint32_t nvec = n / VLEN_FP16; + uint32_t nloe = n % VLEN_FP16; - // Row sum (qf32) HVX_Vector rsum0 = Q6_V_vsplat_R(0); HVX_Vector rsum1 = Q6_V_vsplat_R(0); - HVX_Vector rsum2 = Q6_V_vsplat_R(0); - HVX_Vector rsum3 = Q6_V_vsplat_R(0); - for (uint32_t i = 0; i < nb; i += 4) { - HVX_Vector_x4 vx = hvx_vec_load_x4_f16(x + (i * x_blk_size)); - HVX_Vector_x4 vy = hvx_vec_load_x4_f32_as_f16(y + (i * y_blk_size)); + uint32_t i = 0; - HVX_VectorPair fa0 = Q6_Wqf32_vmpy_VhfVhf(vx.v[0], vy.v[0]); - HVX_VectorPair fa1 = Q6_Wqf32_vmpy_VhfVhf(vx.v[1], vy.v[1]); - HVX_VectorPair fa2 = Q6_Wqf32_vmpy_VhfVhf(vx.v[2], vy.v[2]); - HVX_VectorPair fa3 = Q6_Wqf32_vmpy_VhfVhf(vx.v[3], vy.v[3]); + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + HVX_Vector y_hf = y[i]; + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0[i], y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1[i], y_hf); - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa0), Q6_V_hi_W(fa0))); - rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa1), Q6_V_hi_W(fa1))); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa2), Q6_V_hi_W(fa2))); - rsum3 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum3, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(fa3), Q6_V_hi_W(fa3))); + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); } - // Reduce and convert into fp32 - rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum1); - rsum2 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum2, rsum3); - HVX_Vector rsum = hvx_vec_qf32_reduce_sum(Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, rsum2)); - hvx_vec_store_u(s, 4, Q6_Vsf_equals_Vqf32(rsum)); -} -#endif + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x0_hf = Q6_V_vand_QV(bmask, x0[i]); + HVX_Vector x1_hf = Q6_V_vand_QV(bmask, x1[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); -#define htp_matmul_preamble \ + HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf); + HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf); + + rsum0 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum0, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf))); + rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); + } + + rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum0)); + rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum1)); + HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4); + + hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0)); +} + +static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { + const HVX_UVector * restrict x = (const HVX_UVector *) vx; + const HVX_UVector * restrict y = (const HVX_UVector *) vy; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(4) + for (i = 0; i < nvec; i++) { + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x[i], y[i]); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + HVX_Vector x_hf = Q6_V_vand_QV(bmask, x[i]); + HVX_Vector y_hf = Q6_V_vand_QV(bmask, y[i]); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * restrict x, const void * restrict y) { + const HVX_UVector * restrict vx = (const HVX_UVector * restrict) x; + const HVX_UVector * restrict vy = (const HVX_UVector * restrict) y; + + uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors + uint32_t nloe = n % VLEN_FP16; // leftover elements + + const HVX_Vector zero = Q6_V_vsplat_R(0); + + HVX_Vector rsum = Q6_V_vsplat_R(0); + + uint32_t i = 0; + + #pragma unroll(2) + for (i = 0; i < nvec; i++) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + if (nloe) { + // Load y (fp32) and convert into fp16 + HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements + HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements + HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf))); + + // Load x (fp16) + HVX_Vector x_hf = vx[i]; + + // Zero-out unused elements + // Note that we need to clear both x and y because they may contain NANs + HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2); + x_hf = Q6_V_vand_QV(bmask, x_hf); + y_hf = Q6_V_vand_QV(bmask, y_hf); + + HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); + + rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); + } + + rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum)); + hvx_vec_store_u(&s[0], 4, rsum); +} + +#define htp_matmul_tensors_preamble \ + struct htp_tensor * restrict src0 = &octx->src0; \ + struct htp_tensor * restrict src1 = &octx->src1; \ + struct htp_tensor * restrict src2 = &octx->src2; \ + struct htp_tensor * restrict dst = &octx->dst; \ + struct htp_spad * restrict src0_spad = &octx->src0_spad; \ + struct htp_spad * restrict src1_spad = &octx->src1_spad; \ + struct htp_spad * restrict dst_spad = &octx->dst_spad; \ + \ const uint32_t ne00 = src0->ne[0]; \ const uint32_t ne01 = src0->ne[1]; \ const uint32_t ne02 = src0->ne[2]; \ @@ -1056,6 +1085,11 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t ne12 = src1->ne[2]; \ const uint32_t ne13 = src1->ne[3]; \ \ + const uint32_t ne20 = src2->ne[0]; \ + const uint32_t ne21 = src2->ne[1]; \ + const uint32_t ne22 = src2->ne[2]; \ + const uint32_t ne23 = src2->ne[3]; \ + \ const uint32_t ne0 = dst->ne[0]; \ const uint32_t ne1 = dst->ne[1]; \ const uint32_t ne2 = dst->ne[2]; \ @@ -1076,18 +1110,94 @@ static void vec_dot_f16_f32(const int n, float * restrict s, const void * restri const uint32_t nb2 = dst->nb[2]; \ const uint32_t nb3 = dst->nb[3]; -// q8x4 src1 tensor is already in VTCM spad -static void matmul(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +#define htp_matmul_preamble \ + htp_matmul_tensors_preamble; \ + dma_queue *dma_queue = octx->ctx->dma[ith]; \ + uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread; + +// *** matmul with support for 4d tensors and full broadcasting + +static void matmul_4d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { + htp_matmul_preamble; + + uint64_t t1, t2; + t1 = HAP_perf_get_qtimer_count(); + + assert(ne12 % ne02 == 0); + assert(ne13 % ne03 == 0); + + // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) + const uint32_t nr0 = ne0; + + // This is the size of the rest of the dimensions of the result + const uint32_t nr1 = ne1 * ne2 * ne3; + + // distribute the thread work across the inner or outer loop based on which one is larger + uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows + uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows + + // The number of elements in each chunk + const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; + const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; + + uint32_t current_chunk = ith; + + const uint32_t ith0 = current_chunk % nchunk0; + const uint32_t ith1 = current_chunk / nchunk0; + + const uint32_t ir0_start = dr0 * ith0; + const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); + + const uint32_t ir1_start = dr1 * ith1; + const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); + + // no work for this thread + if (ir0_start >= ir0_end || ir1_start >= ir1_end) { + return; + } + + // block-tiling attempt + const uint32_t blck_0 = 64; + const uint32_t blck_1 = 64; + + for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { + for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { + for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { + const uint32_t i13 = fastdiv(ir1, &octx->mm_div_ne12_ne1); + const uint32_t i12 = fastdiv(ir1 - i13 * ne12 * ne1, &octx->mm_div_ne1); + const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); + + // broadcast src0 into src1 + const uint32_t i03 = fastdiv(i13, &octx->mm_div_r3); + const uint32_t i02 = fastdiv(i12, &octx->mm_div_r2); + + const uint32_t i1 = i11; + const uint32_t i2 = i12; + const uint32_t i3 = i13; + + const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); + const uint8_t * restrict src1_col = (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); + float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); + + const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); + for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { + const uint8_t * restrict src0_row = src0_base + ir0 * nb01; + mt->vec_dot(ne00, &dst_col[ir0], src0_row, src1_col); + } + } + } + } + + t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "matmul-4d %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, + src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], + src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], + (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// src1 tensor is already in VTCM spad +static void matmul_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows @@ -1104,9 +1214,10 @@ static void matmul(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1124,11 +1235,11 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows @@ -1137,17 +1248,17 @@ static void matmul(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); - mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_stride, src1_col); } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1155,13 +1266,13 @@ static void matmul(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { uint32_t ir0 = src0_end_row_x2; const int is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; #pragma unroll(2) for (uint32_t ir1 = 0; ir1 < src1_nrows; ++ir1) { - const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_row_size); + const uint8_t * restrict src1_col = (const uint8_t *) (src1_data + ir1 * src1_stride); float * restrict dst_row = (float *) (dst->data + (ir1 * dst_row_size)); mt->vec_dot(ne00, &dst_row[ir0], ss0, src1_col); } @@ -1176,17 +1287,7 @@ static void matmul(struct htp_matmul_type * mt, } // q8x4x2 src1 tensor is already in VTCM spad -static void matvec(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +static void matvec_2d(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; const uint32_t src0_nrows = ne01; @@ -1202,9 +1303,10 @@ static void matvec(struct htp_matmul_type * mt, const size_t dst_row_size = nb1; const size_t src0_row_size = nb01; - const size_t src1_row_size = q8x4x2_row_size(ne10); + const size_t src1_row_size = nb11; - const size_t src0_row_size_padded = htp_round_up(src0_row_size, 128); + const size_t src0_stride = src0_spad->stride; + const size_t src1_stride = src1_spad->stride; // Per-thread VTCM scratchpads for all tensors // Note that the entire src1 tensor is already in VTCM @@ -1226,24 +1328,24 @@ static void matvec(struct htp_matmul_type * mt, #pragma unroll(2) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint32_t is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 2); } // Process src0 rows for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; - mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_row_size_padded, src1_col); + mt->vec_dot_rx2(ne00, &tmp[ir0 - src0_start_row], ss0, src0_stride, src1_col); // Prefetch next (n + spad_nrows) row - const uint32_t pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const uint32_t is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const uint32_t pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const uint32_t is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), - src0_row_size_padded, src0_row_size, 2); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + pr0 * src0_row_size), + src0_stride, src0_row_size, 2); } } @@ -1251,8 +1353,8 @@ static void matvec(struct htp_matmul_type * mt, if (src0_end_row != src0_end_row_x2) { const uint32_t ir0 = src0_end_row_x2; const uint32_t is0 = (ir0 - src0_start_row); - dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), - src0_row_size_padded, src0_row_size, 1); + dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_stride, src0_row + ir0 * src0_row_size), + src0_stride, src0_row_size, 1); const uint8_t * ss0 = dma_queue_pop(dma_queue).dst; mt->vec_dot(ne00, &tmp[ir0 - src0_start_row], ss0, src1_col); } @@ -1274,22 +1376,13 @@ struct mmid_row_mapping { uint32_t i2; }; -// q8x4 src1 tensor is already in VTCM spad -static void matmul_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict ids, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matmul_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1340,7 +1433,7 @@ static void matmul_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1365,8 +1458,8 @@ static void matmul_id(struct htp_matmul_type * mt, } // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1404,22 +1497,13 @@ static void matmul_id(struct htp_matmul_type * mt, dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// q8x4 src1 tensor is already in VTCM spad -static void matvec_id(struct htp_matmul_type * mt, - struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict src2, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict src2_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { +// src1 tensor is already in VTCM spad +static void matvec_id(struct htp_matmul_type * mt, struct htp_ops_context * octx, uint32_t nth, uint32_t ith) { htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; + struct htp_spad * restrict src2_spad = &octx->src2_spad; + uint64_t t1, t2; t1 = HAP_perf_get_qtimer_count(); @@ -1464,7 +1548,7 @@ static void matvec_id(struct htp_matmul_type * mt, #pragma unroll(4) for (uint32_t ir0 = src0_start_row; ir0 < src0_end_row_x2; ir0 += 2) { const int is0 = (ir0 - src0_start_row); - if (is0 >= HTP_SPAD_SRC0_NROWS) { + if (is0 >= MM_SPAD_SRC0_NROWS) { break; } dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + ir0 * src0_row_size), @@ -1477,8 +1561,8 @@ static void matvec_id(struct htp_matmul_type * mt, mt->vec_dot_rx2(ne00, &dst_row[ir0], ss0, src0_row_size_padded, src1_col); // Prefetch next (n + spad_nrows) row - const int pr0 = (ir0 + HTP_SPAD_SRC0_NROWS); - const int is0 = (pr0 - src0_start_row) % HTP_SPAD_SRC0_NROWS; + const int pr0 = (ir0 + MM_SPAD_SRC0_NROWS); + const int is0 = (pr0 - src0_start_row) % MM_SPAD_SRC0_NROWS; if (pr0 < src0_end_row_x2) { dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(spad_src0 + is0 * src0_row_size_padded, src0_row + pr0 * src0_row_size), src0_row_size_padded, src0_row_size, 2); @@ -1504,106 +1588,6 @@ static void matvec_id(struct htp_matmul_type * mt, dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } -// *** matmul in fp16 - -static void matmul_f16_f32(struct htp_tensor * restrict src0, - struct htp_tensor * restrict src1, - struct htp_tensor * restrict dst, - struct htp_spad * restrict src0_spad, - struct htp_spad * restrict src1_spad, - struct htp_spad * restrict dst_spad, - uint32_t nth, - uint32_t ith, - uint32_t src0_nrows_per_thread, - dma_queue * dma_queue) { - htp_matmul_preamble; - - uint64_t t1, t2; - t1 = HAP_perf_get_qtimer_count(); - - assert(ne12 % ne02 == 0); - assert(ne13 % ne03 == 0); - - // This is the size of the first dimension of the result, so we can iterate that way. (see the ASSERT above, these are the same numbers) - const uint32_t nr0 = ne0; - - // This is the size of the rest of the dimensions of the result - const uint32_t nr1 = ne1 * ne2 * ne3; - - // distribute the thread work across the inner or outer loop based on which one is larger - uint32_t nchunk0 = nr0 > nr1 ? nth : 1; // parallelize by src0 rows - uint32_t nchunk1 = nr0 > nr1 ? 1 : nth; // parallelize by src1 rows - - // The number of elements in each chunk - const uint32_t dr0 = (nr0 + nchunk0 - 1) / nchunk0; - const uint32_t dr1 = (nr1 + nchunk1 - 1) / nchunk1; - - uint32_t current_chunk = ith; - - const uint32_t ith0 = current_chunk % nchunk0; - const uint32_t ith1 = current_chunk / nchunk0; - - const uint32_t ir0_start = dr0 * ith0; - const uint32_t ir0_end = MIN(ir0_start + dr0, nr0); - - const uint32_t ir1_start = dr1 * ith1; - const uint32_t ir1_end = MIN(ir1_start + dr1, nr1); - - // broadcast factors - const uint32_t r2 = ne12 / ne02; - const uint32_t r3 = ne13 / ne03; - - // no work for this thread - if (ir0_start >= ir0_end || ir1_start >= ir1_end) { - return; - } - - // block-tiling attempt - const uint32_t blck_0 = 64; - const uint32_t blck_1 = 64; - - __attribute__((aligned(128))) float tmp[64]; - - for (uint32_t iir1 = ir1_start; iir1 < ir1_end; iir1 += blck_1) { - for (uint32_t iir0 = ir0_start; iir0 < ir0_end; iir0 += blck_0) { - for (uint32_t ir1 = iir1; ir1 < MIN(iir1 + blck_1, ir1_end); ir1++) { - const uint32_t i13 = (ir1 / (ne12 * ne1)); - const uint32_t i12 = (ir1 - i13 * ne12 * ne1) / ne1; - const uint32_t i11 = (ir1 - i13 * ne12 * ne1 - i12 * ne1); - - // broadcast src0 into src1 - const uint32_t i03 = i13 / r3; - const uint32_t i02 = i12 / r2; - - const uint32_t i1 = i11; - const uint32_t i2 = i12; - const uint32_t i3 = i13; - - const uint8_t * restrict src0_base = (const uint8_t *) src0->data + (0 + i02 * nb02 + i03 * nb03); - const uint8_t * restrict src1_col = - (const uint8_t *) src1->data + (i11 * nb11 + i12 * nb12 + i13 * nb13); - float * dst_col = (float *) ((uint8_t * restrict) dst->data + (i1 * nb1 + i2 * nb2 + i3 * nb3)); - - const uint32_t ir0_block_end = MIN(iir0 + blck_0, ir0_end); - for (uint32_t ir0 = iir0; ir0 < ir0_block_end; ir0++) { - // Use nb01 stride for non-contiguous src0 support - const uint8_t * restrict src0_row = src0_base + ir0 * nb01; - vec_dot_f16_f32(ne00, &tmp[ir0 - iir0], src0_row, src1_col); - } - - hvx_copy_fp32_ua((uint8_t *) &dst_col[iir0], (uint8_t *) tmp, MIN(iir0 + blck_0, ir0_end) - iir0); - } - } - } - - t2 = HAP_perf_get_qtimer_count(); - - FARF(HIGH, "matmul-f16-f32 %d/%d: %ux%ux%ux%u (%u:%u %u:%u) * %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, - src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], ir0_start, ir0_end, ir1_start, ir1_end, src1->ne[0], - src1->ne[1], src1->ne[2], src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], - (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); -} - // *** dynamic quant static inline void quantize_block_fp32_q8x1(float * restrict x, uint8_t * restrict y_q, uint8_t * restrict y_d) { @@ -1780,20 +1764,14 @@ static void quantize_row_fp32_q8x4x2(float * restrict x, uint8_t * restrict y, u for (uint32_t i = 0; i < nb; i++) { #if FP32_QUANTIZE_GROUP_SIZE == 32 - quantize_block_fp32_q8x1(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x1(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x1(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x1(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 64 - quantize_block_fp32_q8x2(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x2(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x2(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x2(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #elif FP32_QUANTIZE_GROUP_SIZE == 128 - quantize_block_fp32_q8x4(x + (i * 2 + 0) * qk / 2, y_q + (i * 2 + 0) * qblk_size / 2, - t_d + (i * 2 + 0) * dblk_size / 2); - quantize_block_fp32_q8x4(x + (i * 2 + 1) * qk / 2, y_q + (i * 2 + 1) * qblk_size / 2, - t_d + (i * 2 + 1) * dblk_size / 2); + quantize_block_fp32_q8x4(x + (i*2 + 0) * qk/2, y_q + (i*2 + 0) * qblk_size/2, t_d + (i*2 + 0) * dblk_size/2); + quantize_block_fp32_q8x4(x + (i*2 + 1) * qk/2, y_q + (i*2 + 1) * qblk_size/2, t_d + (i*2 + 1) * dblk_size/2); #else #error "FP32_QUANTIZE_GROUP_SIZE must be 32, 64, or 128" #endif @@ -1848,14 +1826,95 @@ static void quantize_fp32_q8x4x2(const struct htp_tensor * src, ir_last, src_row_size, dst_row_size, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); } +static void quantize_fp32_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_fp32_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp32-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + +// TODO just a plain copy that should be done via the DMA during the Op setup +static void quantize_fp16_fp16(const struct htp_tensor * src, uint8_t * restrict dst, uint32_t nth, uint32_t ith, + uint32_t nrows_per_thread, uint32_t dst_stride) { + + uint64_t t1 = HAP_perf_get_qtimer_count(); + + const uint32_t ne0 = src->ne[0]; + const uint32_t ne1 = src->ne[1]; + const uint32_t ne2 = src->ne[2]; + const uint32_t ne3 = src->ne[3]; + + const uint32_t nrows = ne1 * ne2 * ne3; // total n_rows + + const uint32_t ir_first = nrows_per_thread * ith; // first row + const uint32_t ir_last = MIN(ir_first + nrows_per_thread, nrows); // last row + + const size_t src_row_size = ne0 * sizeof(float); + const size_t src_stride = src->nb[1]; + + uint8_t * restrict src_data = (uint8_t *) src->data + (src_stride * ir_first); + uint8_t * restrict dst_data = (uint8_t *) dst + (dst_stride * ir_first); + + for (uint32_t i = ir_first; i < ir_last; ++i) { + htp_l2fetch(src_data, 2, src_row_size, src_stride); + hvx_copy_fp16_au(dst_data, src_data, ne0); + + dst_data += dst_stride; + src_data += src_stride; + } + + uint64_t t2 = HAP_perf_get_qtimer_count(); + + FARF(HIGH, "quantize-fp16-fp16: %u/%u : n-rows %u (%u:%u) row-size %u (%u) -> %u usec %u\n", ith, nth, nrows, ir_first, + ir_last, src_row_size, src_stride, dst_stride, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); +} + static void htp_quantize_fp32_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; quantize_fp32_q8x4x2(&octx->src1, octx->src1_spad.data, &octx->src0_spad, n, i, octx->src1_nrows_per_thread); } -// ** matmul callbacks for worker_pool +static void htp_quantize_fp32_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp32_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} -static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_quantize_fp16_fp16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + quantize_fp16_fp16(&octx->src1, octx->src1_spad.data, n, i, octx->src1_nrows_per_thread, octx->src1_spad.stride); +} + +// ** matmul/matvec callbacks for worker_pool + +static void htp_matvec_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1863,11 +1922,10 @@ static void htp_matvec_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1875,11 +1933,10 @@ static void htp_matmul_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1887,11 +1944,10 @@ static void htp_matvec_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1899,11 +1955,10 @@ static void htp_matmul_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1911,11 +1966,10 @@ static void htp_matvec_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_2d(&mt, octx, n, i); } -static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { +static void htp_matmul_2d_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; struct htp_matmul_type mt; @@ -1923,14 +1977,49 @@ static void htp_matmul_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul(&mt, &octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_2d(&mt, octx, n, i); } -static void htp_matmul_f16_f32(unsigned int n, unsigned int i, void * data) { +static void htp_matvec_2d_f16_f16(unsigned int n, unsigned int i, void * data) { struct htp_ops_context * octx = data; - matmul_f16_f32(&octx->src0, &octx->src1, &octx->dst, &octx->src0_spad, &octx->src1_spad, &octx->dst_spad, n, i, - octx->src0_nrows_per_thread, octx->ctx->dma[i]); + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matvec_2d(&mt, octx, n, i); +} + +static void htp_matmul_2d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_aa; + mt.vec_dot_rx2 = vec_dot_f16_f16_aa_rx2; + + matmul_2d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f32(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f32"; + mt.vec_dot = vec_dot_f16_f32_uu; + + matmul_4d(&mt, octx, n, i); +} + +static void htp_matmul_4d_f16_f16(unsigned int n, unsigned int i, void * data) { + struct htp_ops_context * octx = data; + + struct htp_matmul_type mt; + mt.type = "f16-f16"; + mt.vec_dot = vec_dot_f16_f16_uu; + + matmul_4d(&mt, octx, n, i); } // ** matmul-id callbacks for worker_pool @@ -1943,8 +2032,7 @@ static void htp_matvec_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1955,8 +2043,7 @@ static void htp_matmul_id_q4x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1967,8 +2054,7 @@ static void htp_matvec_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1979,8 +2065,7 @@ static void htp_matmul_id_q8x4x2_q8x4x2(unsigned int n, unsigned int i, void * d mt.vec_dot = vec_dot_q8x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_q8x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -1991,8 +2076,7 @@ static void htp_matvec_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matvec_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matvec_id(&mt, octx, n, i); } static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void * data) { @@ -2003,18 +2087,17 @@ static void htp_matmul_id_mxfp4x4x2_q8x4x2(unsigned int n, unsigned int i, void mt.vec_dot = vec_dot_mxfp4x4x2_q8x4x2; mt.vec_dot_rx2 = vec_dot_mxfp4x4x2_q8x4x2_rx2; - matmul_id(&mt, &octx->src0, &octx->src1, &octx->src2, &octx->dst, &octx->src0_spad, &octx->src1_spad, - &octx->src2_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); + matmul_id(&mt, octx, n, i); } // ** main matmul entry point -int op_matmul(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - struct htp_tensor * dst = &octx->dst; +static inline bool htp_is_permuted(const struct htp_tensor * t) { + return t->nb[0] > t->nb[1] || t->nb[1] > t->nb[2] || t->nb[2] > t->nb[3]; +} - htp_matmul_preamble; +int op_matmul(struct htp_ops_context * octx) { + htp_matmul_tensors_preamble; const char * op_type; @@ -2038,9 +2121,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q4x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2048,8 +2131,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2067,9 +2150,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "q8x4x2-fp32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_q8x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_q8x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_q8x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_q8x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2077,8 +2160,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2096,9 +2179,9 @@ int op_matmul(struct htp_ops_context * octx) { op_type = "mxfp4x4x2-f32"; quant_job_func = htp_quantize_fp32_q8x4x2; if (src1_nrows > 1) { - matmul_job_func = htp_matmul_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matmul_2d_mxfp4x4x2_q8x4x2; } else { - matmul_job_func = htp_matvec_mxfp4x4x2_q8x4x2; + matmul_job_func = htp_matvec_2d_mxfp4x4x2_q8x4x2; } src1_row_size = q8x4x2_row_size(ne10); // row size post quantization @@ -2106,8 +2189,8 @@ int op_matmul(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); // src0 spad is also used in dynamic quantizer to store padded src1 rows @@ -2122,20 +2205,69 @@ int op_matmul(struct htp_ops_context * octx) { break; case HTP_TYPE_F16: - op_type = "f16-f32"; - quant_job_func = NULL; // htp_quantize_f32_f16; - matmul_job_func = htp_matmul_f16_f32; + { + // Try optimized f16-f16 path first (src1 in VTCM) + const size_t f16_src1_row_size = htp_round_up(ne10 * 2, 128); + const size_t f16_src1_spad_size = htp_round_up(f16_src1_row_size * src1_nrows, 256); + const size_t f16_src0_spad_size = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256) * octx->n_threads; + const size_t f16_dst_spad_size = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256) * octx->n_threads; - // For all tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size, 256); - octx->src1_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC1_NROWS * src1_row_size, 256); + const size_t f16_total_size = f16_src1_spad_size + f16_src0_spad_size + f16_dst_spad_size; - octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; - octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; - octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + // Default matmul implementation does not support multi-batch src0 (N-vs-N broadcasting). + // It only supports 1-vs-N broadcasting (src0 is 2D) or standard 2D matmul. + const bool is_batched = (ne02 > 1) || (ne03 > 1); + const bool is_permuted = htp_is_permuted(&octx->src0) || htp_is_permuted(&octx->src1); - need_quant = false; + if (!is_batched && !is_permuted && f16_total_size <= octx->ctx->vtcm_size) { + // Optimized path + op_type = "f16-f16"; + quant_job_func = (src1->type == HTP_TYPE_F32) ? htp_quantize_fp32_fp16 : htp_quantize_fp16_fp16; + if (src1_nrows > 1) { + matmul_job_func = htp_matmul_2d_f16_f16; + } else { + matmul_job_func = htp_matvec_2d_f16_f16; + } + + src1_row_size = f16_src1_row_size; // row size post quantization + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); + + octx->src1_spad.size = octx->src1_spad.size_per_thread; + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + } else { + // Fallback to f16/f32 (DDR) if src1 doesn't fit in VTCM or broadcasting is required + quant_job_func = NULL; + if (src1->type == HTP_TYPE_F32) { + op_type = "f16-f32"; + matmul_job_func = htp_matmul_4d_f16_f32; + } else { + op_type = "f16-f16"; + matmul_job_func = htp_matmul_4d_f16_f16; + } + + src1_row_size = nb11; // original row size in DDR + + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size, 256); + octx->src1_spad.size_per_thread = htp_round_up(MM_SPAD_SRC1_NROWS * src1_row_size, 256); + + octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads; + octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads; + octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads; + + // Init fastdiv for matmul_4d (supports broadcasting) + octx->mm_div_ne12_ne1 = init_fastdiv_values(src1->ne[2] * dst->ne[1]); + octx->mm_div_ne1 = init_fastdiv_values(dst->ne[1]); + octx->mm_div_r2 = init_fastdiv_values(src1->ne[2] / src0->ne[2]); + octx->mm_div_r3 = init_fastdiv_values(src1->ne[3] / src0->ne[3]); + + need_quant = false; + } + } break; default: @@ -2166,6 +2298,9 @@ int op_matmul(struct htp_ops_context * octx) { octx->src0_nrows_per_thread = (src0_nrows + octx->n_threads - 1) / octx->n_threads; octx->src0_nrows_per_thread += (octx->src0_nrows_per_thread & 1); // round up to even + octx->src0_spad.stride = src0_row_size_padded; + octx->src1_spad.stride = src1_row_size; + if (need_quant) { // Run quant jobs const uint32_t n_quant_jobs = MIN(src1_nrows, octx->n_threads); @@ -2185,12 +2320,9 @@ int op_matmul(struct htp_ops_context * octx) { // ** main matmul-id entry point int op_matmul_id(struct htp_ops_context * octx) { - const struct htp_tensor * src0 = &octx->src0; - const struct htp_tensor * src1 = &octx->src1; - const struct htp_tensor * ids = &octx->src2; - struct htp_tensor * dst = &octx->dst; + htp_matmul_tensors_preamble; - htp_matmul_preamble; + struct htp_tensor * restrict ids = &octx->src2; const char * op_type; @@ -2228,8 +2360,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2257,8 +2389,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); @@ -2286,8 +2418,8 @@ int op_matmul_id(struct htp_ops_context * octx) { // Entire src1 tensor is placed into the VTCM // For other tensors we allocate N rows per thread, padded to HVX vector size - octx->dst_spad.size_per_thread = htp_round_up(HTP_SPAD_DST_NROWS * dst_row_size, 256); - octx->src0_spad.size_per_thread = htp_round_up(HTP_SPAD_SRC0_NROWS * src0_row_size_padded, 256); + octx->dst_spad.size_per_thread = htp_round_up(MM_SPAD_DST_NROWS * dst_row_size, 256); + octx->src0_spad.size_per_thread = htp_round_up(MM_SPAD_SRC0_NROWS * src0_row_size_padded, 256); octx->src1_spad.size_per_thread = htp_round_up(src1_row_size * src1_nrows, 256); octx->src2_spad.size_per_thread = htp_round_up(matrix_row_counts_size + matrix_row_map_size, 256); diff --git a/ggml/src/ggml-hexagon/htp/set-rows-ops.c b/ggml/src/ggml-hexagon/htp/set-rows-ops.c new file mode 100644 index 0000000000..bdd64fcc8f --- /dev/null +++ b/ggml/src/ggml-hexagon/htp/set-rows-ops.c @@ -0,0 +1,168 @@ +#pragma clang diagnostic ignored "-Wunused-variable" +#pragma clang diagnostic ignored "-Wunused-function" +#pragma clang diagnostic ignored "-Wunused-but-set-variable" + +#ifdef HTP_DEBUG +# define FARF_HIGH 1 +#endif +#include +#include +#include +#include +#include +#include +#include + +#define GGML_COMMON_DECL_C +#include "ggml-common.h" +#include "htp-ctx.h" +#include "htp-msg.h" +#include "htp-ops.h" +#include "hvx-utils.h" +#include "ops-utils.h" + +#define set_rows_preamble \ + const uint32_t ne00 = octx->src0.ne[0]; \ + const uint32_t ne01 = octx->src0.ne[1]; \ + const uint32_t ne02 = octx->src0.ne[2]; \ + const uint32_t ne03 = octx->src0.ne[3]; \ + \ + const uint32_t ne10 = octx->src1.ne[0]; \ + const uint32_t ne11 = octx->src1.ne[1]; \ + const uint32_t ne12 = octx->src1.ne[2]; \ + \ + const uint32_t nb01 = octx->src0.nb[1]; \ + const uint32_t nb02 = octx->src0.nb[2]; \ + const uint32_t nb03 = octx->src0.nb[3]; \ + \ + const uint32_t nb10 = octx->src1.nb[0]; \ + const uint32_t nb11 = octx->src1.nb[1]; \ + const uint32_t nb12 = octx->src1.nb[2]; \ + \ + const uint32_t nb1 = octx->dst.nb[1]; \ + const uint32_t nb2 = octx->dst.nb[2]; \ + const uint32_t nb3 = octx->dst.nb[3]; \ + \ + const uint32_t ne1 = octx->dst.ne[1]; \ + \ + const uint32_t nr = ne01; + +static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + // copy row + hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) { + set_rows_preamble; + + // parallelize by rows of src0 + const uint32_t dr = octx->src0_nrows_per_thread; + const uint32_t ir0 = dr * ith; + const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr; + + const bool is_i32 = (octx->src1.type == HTP_TYPE_I32); + + for (uint32_t i03 = 0; i03 < ne03; ++i03) { + for (uint32_t i02 = 0; i02 < ne02; ++i02) { + for (uint32_t i = ir0; i < ir1; ++i) { + const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12); + const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11); + const uint32_t i10 = i; + + const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12; + + uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr; + if (i1 >= ne1) { + // ignore invalid indices + continue; + } + + const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03; + uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3; + + hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00); + } + } + } + + return HTP_STATUS_OK; +} + +static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i); +} + +static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) { + set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i); +} + +int op_set_rows(struct htp_ops_context * octx) { + set_rows_preamble; + + if (octx->src0.type != HTP_TYPE_F32) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) { + return HTP_STATUS_NO_SUPPORT; + } + + if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) { + return HTP_STATUS_OK; + } + + octx->set_rows_div_ne12 = init_fastdiv_values(ne12); + octx->set_rows_div_ne11 = init_fastdiv_values(ne11); + + const uint32_t n_jobs = MIN(nr, octx->n_threads); + octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs; + + switch(octx->dst.type) { + case HTP_TYPE_F32: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs); + break; + case HTP_TYPE_F16: + worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs); + break; + default: + return HTP_STATUS_NO_SUPPORT; + } + + return HTP_STATUS_OK; +} diff --git a/ggml/src/ggml-hexagon/htp/softmax-ops.c b/ggml/src/ggml-hexagon/htp/softmax-ops.c index 5bf0cbf792..80d249a22c 100644 --- a/ggml/src/ggml-hexagon/htp/softmax-ops.c +++ b/ggml/src/ggml-hexagon/htp/softmax-ops.c @@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, (const uint8_t *) mp_f32, slope); } else { - hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale); + hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); if (mp_f32) { if (softmax_ctx->use_f16) { for (int i = 0; i < ne00; ++i) { @@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); sum = sum > 0.0 ? (1.0 / sum) : 1; - hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum); + hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); } } } diff --git a/ggml/src/ggml-hexagon/htp/unary-ops.c b/ggml/src/ggml-hexagon/htp/unary-ops.c index bb7557b025..8ed1e5b661 100644 --- a/ggml/src/ggml-hexagon/htp/unary-ops.c +++ b/ggml/src/ggml-hexagon/htp/unary-ops.c @@ -83,6 +83,31 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src, } } +static void scale_htp_f32(const float * restrict src, + float * restrict dst, + uint8_t * restrict spad, + const uint32_t num_rows, + const uint32_t row_elems, + const size_t row_size, + int32_t * op_params, + int opt_path) { + float scale = 0.f; + float bias = 0.f; + memcpy(&scale, &op_params[0], sizeof(float)); + memcpy(&bias, &op_params[1], sizeof(float)); + + for (uint32_t ir = 0; ir < num_rows; ir++) { + const float * restrict src_local = src + (ir * row_elems); + float * restrict dst_local = dst + (ir * row_elems); + + if (ir + 1 < num_rows) { + htp_l2fetch(src_local + row_elems, 1, row_size, row_size); + } + + hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias); + } +} + static void rms_norm_htp_f32(const float * restrict src, float * restrict dst, uint8_t * restrict spad, @@ -110,7 +135,7 @@ static void rms_norm_htp_f32(const float * restrict src, const float mean = sum / row_elems; const float scale = 1.0f / sqrtf(mean + epsilon); - hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale); + hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); } } } @@ -162,6 +187,9 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src, case HTP_OP_RMS_NORM: rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); break; + case HTP_OP_SCALE: + scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); + break; default: break; @@ -195,6 +223,10 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) { unary_op_func = unary_job_dispatcher_f32; op_type = "rmsnorm-f32"; break; + case HTP_OP_SCALE: + unary_op_func = unary_job_dispatcher_f32; + op_type = "scale-f32"; + break; default: FARF(ERROR, "Unsupported unary Op %u\n", octx->op); diff --git a/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ggml/src/ggml-metal/ggml-metal-ops.cpp index acf2aa9184..a50b12b6f3 100644 --- a/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -2181,7 +2181,11 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { const bool has_mask = op->src[3] != nullptr; - if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + // note: the non-vec kernel requires more extra memory, so always reserve for it + GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG); + + //if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + if (false) { // note: always reserve the padding space to avoid graph reallocations //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; const bool has_kvpad = true; diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp index 164b39d01e..d7c8ad8c16 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp @@ -1517,10 +1517,12 @@ bool rpc_server::graph_compute(const std::vector & input) { struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); graph->n_nodes = n_nodes; std::unordered_map tensor_ptrs; + tensor_ptrs.reserve(n_tensors); for (uint32_t i = 0; i < n_tensors; i++) { - tensor_ptrs[tensors[i].id] = &tensors[i]; + tensor_ptrs.emplace(tensors[i].id, &tensors[i]); } std::unordered_map tensor_map; + tensor_map.reserve(n_nodes); for (uint32_t i = 0; i < n_nodes; i++) { int64_t id; memcpy(&id, &nodes[i], sizeof(id)); diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp index 541e4a50b7..d68735a040 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -550,6 +550,8 @@ struct vk_device_struct { uint64_t max_memory_allocation_size; uint64_t max_buffer_size; uint64_t suballocation_block_size; + uint64_t min_imported_host_pointer_alignment; + bool external_memory_host {}; bool fp16; bool bf16; bool pipeline_robustness; @@ -765,6 +767,9 @@ struct vk_device_struct { vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_cumsum_f32; + vk_pipeline pipeline_cumsum_small_f32; + vk_pipeline pipeline_cumsum_multipass1_f32; + vk_pipeline pipeline_cumsum_multipass2_f32; vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_count_equal_i32; std::map pipeline_solve_tri_f32; @@ -2407,7 +2412,8 @@ static std::vector ggml_vk_find_memory_properties(const vk::PhysicalDe return indices; } -static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list, + void *import_ptr = nullptr) { VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); if (size > device->max_buffer_size) { throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); @@ -2436,6 +2442,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std nullptr, }; + vk::ExternalMemoryBufferCreateInfo external_memory_bci; + if (import_ptr) { + external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + buffer_create_info.setPNext(&external_memory_bci); + } + buf->buffer = device->device.createBuffer(buffer_create_info); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); @@ -2450,35 +2462,80 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std mem_flags_info.setPNext(&mem_priority_info); } - for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { - const auto & req_flags = *it; - - const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); - - if (memory_type_indices.empty()) { - continue; + if (import_ptr) { + vk::MemoryHostPointerPropertiesEXT host_pointer_props; + try { + host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what()); + device->device.destroyBuffer(buf->buffer); + return {}; } - buf->memory_property_flags = req_flags; + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); - bool done = false; + uint32_t memory_type_idx; + vk::MemoryPropertyFlags property_flags = *req_flags_list.begin(); + for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) { + if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } + if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) { + continue; + } - for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { - try { - buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); - done = true; + vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx]; + // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed + if ((memory_type.propertyFlags & property_flags) == property_flags) { + property_flags = memory_type.propertyFlags; break; - } catch (const vk::SystemError& e) { - // loop and retry - // during last attempt throw the exception - if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { - device->device.destroyBuffer(buf->buffer); - throw e; - } } } + if (memory_type_idx == 32) { + GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n"); + device->device.destroyBuffer(buf->buffer); + return {}; + } - if (done) { - break; + buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags; + try { + vk::ImportMemoryHostPointerInfoEXT import_info; + import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT; + import_info.pHostPointer = import_ptr; + import_info.setPNext(&mem_flags_info); + buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info }); + } catch (const vk::SystemError& e) { + } + } else { + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + const std::vector memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_indices.empty()) { + continue; + } + buf->memory_property_flags = req_flags; + + bool done = false; + + for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) { + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info }); + done = true; + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (done) { + break; + } } } @@ -2489,8 +2546,12 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std buf->ptr = nullptr; - if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { - buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + if (import_ptr) { + buf->ptr = import_ptr; + } else { + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } } device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); @@ -2702,7 +2763,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec switch (src0_type) { case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_M: - lut_size = 2*2048; + lut_size = 2*2048 + 4*2048; break; case GGML_TYPE_IQ2_XXS: lut_size = 8*256; @@ -2895,44 +2956,50 @@ static void ggml_vk_load_shaders(vk_device& device) { const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; - l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; - m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; - s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32; - l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; - m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; - s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; // Integer MMQ has a smaller shared memory profile, but heavier register use - l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; - m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; - s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 }; // K-quants use even more registers, mitigate by setting WMITER to 1 - l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; - m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; - s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 }; + l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 }; - l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; - m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; - s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; - l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; - m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; - s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; - l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; - m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; - s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; + l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 }; + m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 }; + s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 }; - l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; - m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; - s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; + l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 }; + m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 }; + s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 }; // chip specific tuning if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) { + // Xe2/Xe3 with coopmat enabled - warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; } l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; @@ -3615,6 +3682,11 @@ static void ggml_vk_load_shaders(vk_device& device) { m_wg_denoms = { 64, 64, 1 }; s_wg_denoms = { 32, 32, 1 }; + if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) { + // Xe2/Xe3 - bf16 warptile performance tuning + l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 }; + } + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); } @@ -3627,6 +3699,7 @@ static void ggml_vk_load_shaders(vk_device& device) { uint32_t rm_kq = 2; uint32_t rm_stdq_int = 1; uint32_t rm_kq_int = 1; + auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; }; if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->architecture == AMD_GCN) { rm_stdq = 2; @@ -3730,6 +3803,10 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int); + } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3776,6 +3853,9 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int); } #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT } @@ -3783,6 +3863,7 @@ static void ggml_vk_load_shaders(vk_device& device) { #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) GGML_UNUSED(rm_stdq_int); GGML_UNUSED(rm_kq_int); + GGML_UNUSED(rm_iq_int); #endif // dequant shaders @@ -4169,7 +4250,11 @@ static void ggml_vk_load_shaders(vk_device& device) { ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); - ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size); + const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4; + ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); @@ -4429,6 +4514,8 @@ static vk_device ggml_vk_get_device(size_t idx) { } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { device->memory_priority = true; + } else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) { + device->external_memory_host = true; } } @@ -4443,6 +4530,7 @@ static vk_device ggml_vk_get_device(size_t idx) { vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props; props2.pNext = &props3; props3.pNext = &subgroup_props; @@ -4482,11 +4570,22 @@ static vk_device ggml_vk_get_device(size_t idx) { last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; } + if (device->external_memory_host) { + last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props; + last_struct = (VkBaseOutStructure *)&external_memory_host_props; + } + device->physical_device.getProperties2(&props2); device->properties = props2.properties; device->vendor_id = device->properties.vendorID; device->driver_id = driver_props.driverID; + if (device->driver_id == vk::DriverId::eMoltenvk) { + // Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622 + // is available in the Vulkan SDK. + device->external_memory_host = false; + } + // Implementing the async backend interfaces seems broken on older Intel HW, // see https://github.com/ggml-org/llama.cpp/issues/17302. device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || @@ -4568,6 +4667,8 @@ static vk_device ggml_vk_get_device(size_t idx) { device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment; + device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); @@ -4699,6 +4800,10 @@ static vk_device ggml_vk_get_device(size_t idx) { device_extensions.push_back("VK_KHR_pipeline_executable_properties"); } + if (device->external_memory_host) { + device_extensions.push_back("VK_EXT_external_memory_host"); + } + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); device->pipeline_executable_properties_support = pipeline_executable_properties_support; @@ -4965,11 +5070,23 @@ static vk_device ggml_vk_get_device(size_t idx) { switch (device->vendor_id) { #ifndef GGML_VULKAN_RUN_TESTS case VK_VENDOR_ID_AMD: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; case VK_VENDOR_ID_INTEL: - device->mul_mat_l[i] = false; + if (!device->coopmat_support || device->architecture != INTEL_XE2) { + device->mul_mat_l[i] = false; + device->mul_mat_id_l[i] = false; + } else { + device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel + device->mul_mat_id_l[i] = true; + } device->mul_mat_m[i] = true; device->mul_mat_s[i] = true; - device->mul_mat_id_l[i] = false; device->mul_mat_id_m[i] = true; device->mul_mat_id_s[i] = true; break; @@ -5616,6 +5733,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -5772,6 +5891,8 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context case GGML_TYPE_Q4_K: case GGML_TYPE_Q5_K: case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: break; default: return nullptr; @@ -6753,7 +6874,12 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); - ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]); + // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks. + const uint64_t max_elements = std::min(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits::max()); + const uint32_t elements = std::min(ne, static_cast(max_elements)); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ ne, num_blocks }, { elements, 1, 1 }); ggml_vk_sync_buffers(ctx, subctx); } @@ -7037,7 +7163,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_ // Quantization overhead is not worth it for small k switch (device->vendor_id) { case VK_VENDOR_ID_NVIDIA: - if (src0_type == GGML_TYPE_Q2_K) { + if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) { return true; } @@ -8791,7 +8917,11 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const return nullptr; case GGML_OP_CUMSUM: if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { - return ctx->device->pipeline_cumsum_f32; + if (src0->ne[0] <= 512) { + return ctx->device->pipeline_cumsum_small_f32; + } else { + return ctx->device->pipeline_cumsum_f32; + } } return nullptr; case GGML_OP_SOLVE_TRI: @@ -10695,8 +10825,50 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons } static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { - vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); - ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p); + vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + // Use the single pass shader when the rows are small or there are enough rows to fill the GPU. + // For fewer, larger rows, use the multipass shader to spread each row across SMs. + if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc); + return; + } + + // First pass computes partial sums within a block, and stores the last partial + // to the temp buffer. Second pass sums the block partials from the temp buffer + // and adds that to the result of the first pass. + vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32; + vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32; + GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr); + + ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1); + ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1); + + std::array elements; + + elements[0] = dst->ne[0]; + elements[1] = (uint32_t)ggml_nrows(dst); + elements[2] = 1; + + size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst); + + if (ctx->prealloc_size_split_k < temp_size) { + ctx->prealloc_size_split_k = temp_size; + ggml_vk_preallocate_buffers(ctx, subctx); + } + + vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0); + vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst); + vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0); + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements); + + ctx->prealloc_split_k_need_sync = true; } static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { @@ -14133,6 +14305,19 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const } static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + + // reject any tensors larger than the max buffer size + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) { + return false; + } + } + if (ggml_nbytes(op) > device->max_buffer_size) { + return false; + } + switch (op->op) { case GGML_OP_UNARY: switch (ggml_get_unary_op(op)) { @@ -14181,8 +14366,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm case GGML_OP_MUL_MAT_ID: { ggml_type src0_type = op->src[0]->type; - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); if (op->op == GGML_OP_MUL_MAT_ID) { if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { // If there's not enough shared memory for row_ids and the result tile, fallback to CPU @@ -14243,8 +14426,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_FLASH_ATTN_EXT: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); bool coopmat2 = device->coopmat2; uint32_t HSK = op->src[1]->ne[0]; uint32_t HSV = op->src[2]->ne[0]; @@ -14466,8 +14647,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // pipeline_argsort_large_f32 requires vulkan memory model. if (device->vulkan_memory_model) { return true; @@ -14480,8 +14659,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) { return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); // We could potentially support larger, using argsort to sort the // whole thing. Not clear if this is needed. uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1; @@ -14528,8 +14705,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); case GGML_OP_CUMSUM: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - auto device = ggml_vk_get_device(ctx->device); if (device->subgroup_arithmetic && device->subgroup_require_full_support) { return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); } @@ -14537,9 +14712,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm } case GGML_OP_SOLVE_TRI: { - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) { return false; } @@ -14604,9 +14776,6 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm return false; } - ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; - const vk_device& device = ggml_vk_get_device(ctx->device); - const uint32_t SPLIT_H = 16; size_t stateC_size = SPLIT_H * d_state * sizeof(float); @@ -14700,6 +14869,51 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize"); } +static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) { + if (!device->external_memory_host) { + return {}; + } + + uintptr_t uptr = reinterpret_cast(ptr); + if (uptr & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + if (size & (device->min_imported_host_pointer_alignment - 1)) { + return {}; + } + + const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached; + + vk_buffer buf {}; + try { + buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what()); + } + + return buf; +} + +static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) { + VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")"); + GGML_UNUSED(max_tensor_size); + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + + vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size); + + if (!buf) { + return {}; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name); + + ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size); + + return ret; +} + static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .get_name = */ ggml_backend_vk_device_get_name, /* .get_description = */ ggml_backend_vk_device_get_description, @@ -14709,7 +14923,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = { /* .init_backend = */ ggml_backend_vk_device_init, /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, - /* .buffer_from_host_ptr = */ NULL, + /* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr, /* .supports_op = */ ggml_backend_vk_device_supports_op, /* .supports_buft = */ ggml_backend_vk_device_supports_buft, /* .offload_op = */ ggml_backend_vk_device_offload_op, diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp index a4c8fc354e..75e3c3b0eb 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum.comp @@ -14,6 +14,7 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (constant_id = 0) const uint BLOCK_SIZE = 128; layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; +layout (constant_id = 2) const uint ELEM_PER_THREAD = 4; #define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) @@ -38,32 +39,45 @@ void main() { last_sum = 0; } - uint col = tid; - uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE); + uint col = tid * ELEM_PER_THREAD; + uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD); for (int i = 0; i < num_iter; ++i) { - FLOAT_TYPE v = 0; - if (col < p.n_cols) { - v = FLOAT_TYPE(data_a[src_idx + col]); + FLOAT_TYPE v[ELEM_PER_THREAD]; + FLOAT_TYPE thread_sum = 0; + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]); + } + v[j] = thread_sum; } - v = subgroupInclusiveAdd(v); + thread_sum = subgroupExclusiveAdd(thread_sum); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += thread_sum; + } // Store the largest partial sum for each subgroup, then add the partials for all // lower subgroups and the final partial sum from the previous iteration. if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { - partial[subgroup_id] = v; + partial[subgroup_id] = v[ELEM_PER_THREAD - 1]; } barrier(); - for (int j = 0; j < subgroup_id; ++j) { - v += partial[j]; + for (int s = 0; s < subgroup_id; ++s) { + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += partial[s]; + } + } + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + v[j] += last_sum; } - v += last_sum; barrier(); if (tid == BLOCK_SIZE - 1) { - last_sum = v; + last_sum = v[ELEM_PER_THREAD - 1]; } - if (col < p.n_cols) { - data_d[dst_idx + col] = D_TYPE(v); + [[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) { + if (col + j < p.n_cols) { + data_d[dst_idx + col + j] = D_TYPE(v[j]); + } } - col += BLOCK_SIZE; + col += BLOCK_SIZE * ELEM_PER_THREAD; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp new file mode 100644 index 0000000000..6d39f927fc --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass1.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; +layout (binding = 2) writeonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + const uint col = gl_GlobalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + uint subgroup_id = tid / SUBGROUP_SIZE; + + FLOAT_TYPE v = 0; + if (col < p.n_cols) { + v = FLOAT_TYPE(data_a[src_idx + col]); + } + v = subgroupInclusiveAdd(v); + + // Store the largest partial sum for each subgroup, then add the partials for all + // lower subgroups and the final partial sum from the previous iteration. + if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) { + partial[subgroup_id] = v; + } + barrier(); + for (int j = 0; j < subgroup_id; ++j) { + v += partial[j]; + } + barrier(); + if (tid == BLOCK_SIZE - 1) { + data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v; + } + if (col < p.n_cols) { + data_d[dst_idx + col] = D_TYPE(v); + } +} diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp new file mode 100644 index 0000000000..e401893466 --- /dev/null +++ b/ggml/src/ggml-vulkan/vulkan-shaders/cumsum_multipass2.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "types.glsl" +#include "sum_rows.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {D_TYPE data_d[];}; +layout (binding = 2) readonly buffer T {D_TYPE data_t[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 128; +layout (constant_id = 1) const uint SUBGROUP_SIZE = 32; + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + +shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.y; + const uint tid = gl_LocalInvocationID.x; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + const uint col = gl_GlobalInvocationID.x; + + float v = 0; + // prefetch value we're adding to + if (col < p.n_cols) { + v = data_d[dst_idx + col]; + } + + // compute the sum of all previous blocks + uint c = tid; + float sum = 0; + while (c < gl_WorkGroupID.x) { + sum += data_t[c + gl_NumWorkGroups.x * row]; + c += BLOCK_SIZE; + } + + sum = subgroupAdd(sum); + if (gl_SubgroupInvocationID == 0) { + temp[gl_SubgroupID] = sum; + } + barrier(); + sum = 0; + [[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) { + sum += temp[s]; + } + + // Add the sum to what the first pass computed + if (col < p.n_cols) { + data_d[dst_idx + col] = v + sum; + } +} + diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl index 376944f1e2..7865a6bda7 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -462,7 +462,8 @@ vec2 get_dm(uint ib, uint a_offset) { #if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) vec2 get_dm(uint ib, uint a_offset) { - return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); + const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm); + return dm; } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp index 15f005be3e..ff5f43979d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -14,6 +14,8 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; #define K_PER_ITER 8 #elif defined(DATA_A_QUANT_K) #define K_PER_ITER 16 +#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define K_PER_ITER 32 #else #error unimplemented #endif @@ -49,6 +51,15 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1]; cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2]; cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3]; +#elif K_PER_ITER == 32 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1]; + cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2]; + cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3]; + cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4]; + cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5]; + cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6]; + cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7]; #else #error unimplemented #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl index 2389ea0b1e..6ddbed309d 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq_funcs.glsl @@ -377,3 +377,118 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum)); } #endif + +#if defined(DATA_A_IQ1_S) +void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + + const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2]; + const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2]; + + const uint qs0 = qs16_0 & 0xFF; + const uint qs1 = qs16_0 >> 8; + const uint qs2 = qs16_1 & 0xFF; + const uint qs3 = qs16_1 >> 8; + + const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3); + const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3); + const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3); + const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3); + + const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]); + const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]); + const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]); + const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]); + + out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F, + (grid0 >> 4) & 0x0F0F0F0F, + (grid1 >> 0) & 0x0F0F0F0F, + (grid1 >> 4) & 0x0F0F0F0F); + out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F, + (grid2 >> 4) & 0x0F0F0F0F, + (grid3 >> 0) & 0x0F0F0F0F, + (grid3 >> 4) & 0x0F0F0F0F); +} + +vec2 get_dm(uint ib, uint iqs) { + const uint ib32 = iqs / 32; + + const uint qh = data_a[ib].qh[ib32]; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + + // the -1 cancels out the bias in iq1s_grid_gpu + return FLOAT_TYPE_VEC2(dl, dl * (delta - 1)); +} + +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + int32_t q_sum = 0; + + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + i32vec4 qs_a0; + i32vec4 qs_a1; + repack8(ib_k, iqs_k, qs_a0, qs_a1); + + const vec2 dm = get_dm(ib_k, iqs_k); + + q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]); + q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]); + q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]); + q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]); + q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]); + q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]); + q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]); + q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]); + + return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y)); +} +#endif + +#if defined(DATA_A_IQ1_M) +FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) { + const uint ib_k = ib_a / 8; + const uint iqs_k = (ib_a % 8) * 32 + iqs * 32; + + const uint ib32 = iqs_k / 32; + const uint ib64 = ib32 / 2; + + const uint16_t[4] scales = data_a[ib_k].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint qs32 = data_a_packed32[ib_k].qs[ib32]; + const uint qh16 = data_a_packed16[ib_k].qh[ib32]; + + float sum = 0; + const uint sc = data_a[ib_k].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = qh16 >> (4 * l); + const uint qs = (qs32 >> (8 * l)) & 0xFF; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + + const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]); + + int32_t q_sum = 0; + q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]); + q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]); + + int32_t y_sum = 0; + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]); + y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]); + + // the -1 cancels out the bias in iq1s_grid_gpu + sum += dl * (q_sum + y_sum * (delta - 1)); + } + sum *= float(cache_b_ds.x); + + return sum; +} +#endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl index 1a3531761a..ce7f2d699a 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin #endif #elif defined(DATA_A_Q4_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; @@ -63,16 +63,15 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q4_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 4; const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); - const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; - const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y; buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); @@ -80,7 +79,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); #elif defined(DATA_A_Q5_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -97,22 +96,26 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); #elif defined(DATA_A_Q5_1) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; - const uint ib = idx / 8; - const uint iqs = idx & 0x07; + const uint ib = idx / 4; + const uint iqs = idx & 0x03; - const float d = float(data_a_packed16[ib].d); - const float m = float(data_a_packed16[ib].m); - const uint uint_qh = data_a_packed16[ib].qh; - const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); - const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + const vec2 dm = vec2(data_a_packed32[ib].dm); + const uint uint_qh = data_a_packed32[ib].qh; + const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10); + const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10); + const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10); + const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10); - const uint vui = uint(data_a_packed16[ib].qs[iqs]); - const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + const uint vui = data_a_packed32[ib].qs[iqs]; + const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y; + const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y; - buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); - buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw); #elif defined(DATA_A_Q8_0) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -131,20 +134,21 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15 const uint scalesi = iqs / 8; // 0..15 const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 - const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi])); + const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303)); const uint scales = data_a[ib].scales[scalesi]; const vec2 dm = vec2(data_a[ib].dm); - const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4); + const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); #elif defined(DATA_A_Q3_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -173,8 +177,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -200,16 +204,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy); + const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q5_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; - const uint ib = idx / 128; // 2 values per idx - const uint iqs = idx % 128; // 0..127 + const uint ib = idx / 64; // 4 values per idx + const uint iqs = (idx % 64) * 2; // 0,2,4..126 const uint n = iqs / 32; // 0,1,2,3 const uint b = (iqs % 32) / 16; // 0,1 @@ -236,12 +240,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin const float d = loadd.x * sc; const float m = -loadd.y * mbyte; - const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F; - const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4; - const vec2 q = vec2(unpack8(qs | qh).xy); + const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F; + const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4; + const vec4 q = vec4(unpack8(qs | qh)); - buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m), - fma(d, q.y, m)); + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m)); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m)); #elif defined(DATA_A_Q6_K) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; @@ -455,7 +459,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); #elif defined(DATA_A_IQ4_NL) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = idx & 0x07; @@ -469,7 +473,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin kvalues_iq4nl[vui >> 12]); #elif defined(DATA_A_MXFP4) const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; - const uint buf_idx = col * SHMEM_STRIDE + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4; const uint ib = idx / 8; const uint iqs = (idx & 0x07) * 2; diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp index 20e45d0253..7ea29a07e3 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -15,6 +15,7 @@ layout (push_constant) uniform parameter { uint ne; + uint num_blocks; } p; #include "types.glsl" @@ -33,8 +34,7 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; shared float shmem[GROUP_SIZE]; #endif -void quantize() { - const uint wgid = gl_WorkGroupID.x; +void quantize(const uint wgid) { const uint tid = INVOCATION_ID; // Each thread handles a vec4, so 8 threads handle a block @@ -45,11 +45,7 @@ void quantize() { const uint ib = wgid * blocks_per_group + block_in_wg; const uint iqs = tid % 8; -#ifndef QBLOCK_X4 - if (ib >= gl_NumWorkGroups.x * blocks_per_group) { - return; - } -#else +#ifdef QBLOCK_X4 const uint ibx4_outer = ib / 4; const uint ibx4_inner = ib % 4; @@ -123,5 +119,9 @@ void quantize() { } void main() { - quantize(); + uint wgid = gl_WorkGroupID.x; + while (wgid < p.num_blocks) { + quantize(wgid); + wgid += gl_NumWorkGroups.x; + } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp index 4bf6d2bcb0..ef2f202ec9 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/topk_moe.comp @@ -101,6 +101,10 @@ void main() { const uint lane = gl_SubgroupInvocationID; float probs[experts_per_thread]; + [[unroll]] + for (int i = 0; i < experts_per_thread; i++) { + probs[i] = -INFINITY; + } [[unroll]] for (uint i = 0; i < n_experts; i += WARP_SIZE) { @@ -112,8 +116,9 @@ void main() { softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push); } else if (gating_func == GATING_FUNC_SIGMOID) { [[unroll]] - for (int i = 0; i < experts_per_thread; i++) { - probs[i] = 1.f / (1.f + exp(-probs[i])); + for (uint i = 0; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY; } } @@ -150,11 +155,11 @@ void main() { uint max_expert = lane; [[unroll]] - for (int i = 1; i < experts_per_thread; i++) { - const uint expert = lane + i * WARP_SIZE; - if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i] > max_val_s) { - max_val = probs[i]; - max_val_s = selection_probs[i]; + for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) { + const uint expert = i + lane; + if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) { + max_val = probs[i / WARP_SIZE]; + max_val_s = selection_probs[i / WARP_SIZE]; max_expert = expert; } } diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl index 402a2a8397..bdb2c09259 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl +++ b/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -396,6 +396,12 @@ struct block_iq1_s { uint16_t qh[QUANT_K_IQ1_S/32]; }; +struct block_iq1_s_packed16 { + float16_t d; + uint16_t qs[QUANT_K_IQ1_S/8/2]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + #define QUANT_K_IQ1_M 256 #define QUANT_R_IQ1_M 1 @@ -405,6 +411,18 @@ struct block_iq1_m { uint16_t scales[QUANT_K_IQ1_M/64]; }; +struct block_iq1_m_packed16 { + uint16_t qs[QUANT_K_IQ1_M/8/2]; + uint16_t qh[QUANT_K_IQ1_M/16/2]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed32 { + uint32_t qs[QUANT_K_IQ1_M/8/4]; + uint32_t qh[QUANT_K_IQ1_M/16/4]; + uint32_t scales[QUANT_K_IQ1_M/64/2]; +}; + struct block_iq1_m_packed64 { uint64_t qs[QUANT_K_IQ1_M/8/8]; uint64_t qh[QUANT_K_IQ1_M/16/8]; @@ -415,12 +433,15 @@ struct block_iq1_m_packed64 { #define QUANT_K QUANT_K_IQ1_S #define QUANT_R QUANT_R_IQ1_S #define A_TYPE block_iq1_s +#define A_TYPE_PACKED16 block_iq1_s_packed16 #endif #if defined(DATA_A_IQ1_M) #define QUANT_K QUANT_K_IQ1_M #define QUANT_R QUANT_R_IQ1_M #define A_TYPE block_iq1_m +#define A_TYPE_PACKED16 block_iq1_m_packed16 +#define A_TYPE_PACKED32 block_iq1_m_packed32 #endif #if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) @@ -559,7 +580,270 @@ const uint[1024] iq1s_grid_const = { 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 }; +// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit +// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F +// and 0xF0F0F0F0). +const uint32_t[2048] iq1s_grid_gpu_const = { + 0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000, + 0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101, + 0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200, + 0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212, + 0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011, + 0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111, + 0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220, + 0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022, + 0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220, + 0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101, + 0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110, + 0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111, + 0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010, + 0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210, + 0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221, + 0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021, + 0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002, + 0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101, + 0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101, + 0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211, + 0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110, + 0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022, + 0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121, + 0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220, + 0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001, + 0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101, + 0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102, + 0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012, + 0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010, + 0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111, + 0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122, + 0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222, + 0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001, + 0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102, + 0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101, + 0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000, + 0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101, + 0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112, + 0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110, + 0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211, + 0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012, + 0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111, + 0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120, + 0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122, + 0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121, + 0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221, + 0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001, + 0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101, + 0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101, + 0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011, + 0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111, + 0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011, + 0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122, + 0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121, + 0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222, + 0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101, + 0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000, + 0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200, + 0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110, + 0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112, + 0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222, + 0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021, + 0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121, + 0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201, + 0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200, + 0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101, + 0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011, + 0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010, + 0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211, + 0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121, + 0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000, + 0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202, + 0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202, + 0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211, + 0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112, + 0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020, + 0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121, + 0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222, + 0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102, + 0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100, + 0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110, + 0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011, + 0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111, + 0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110, + 0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121, + 0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222, + 0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201, + 0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102, + 0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201, + 0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012, + 0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010, + 0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010, + 0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110, + 0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011, + 0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212, + 0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021, + 0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021, + 0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021, + 0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101, + 0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101, + 0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100, + 0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010, + 0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111, + 0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010, + 0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111, + 0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120, + 0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120, + 0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101, + 0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001, + 0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201, + 0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210, + 0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211, + 0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111, + 0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112, + 0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211, + 0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010, + 0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021, + 0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122, + 0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221, + 0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102, + 0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100, + 0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101, + 0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101, + 0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101, + 0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012, + 0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110, + 0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112, + 0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210, + 0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210, + 0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210, + 0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010, + 0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110, + 0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122, + 0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020, + 0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021, + 0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022, + 0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120, + 0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222, + 0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221, + 0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001, + 0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102, + 0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201, + 0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012, + 0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111, + 0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012, + 0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110, + 0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110, + 0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121, + 0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221, + 0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220, + 0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222, + 0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000, + 0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201, + 0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012, + 0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011, + 0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212, + 0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221, + 0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121, + 0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202, + 0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202, + 0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002, + 0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101, + 0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210, + 0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112, + 0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011, + 0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011, + 0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210, + 0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020, + 0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220, + 0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222, + 0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222, + 0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001, + 0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010, + 0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111, + 0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010, + 0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110, + 0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221, + 0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122, + 0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202, + 0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100, + 0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101, + 0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112, + 0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111, + 0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211, + 0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222, + 0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221, + 0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022, + 0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101, + 0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211, + 0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111, + 0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111, + 0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010, + 0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121, + 0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222, + 0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000, + 0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202, + 0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000, + 0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202, + 0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110, + 0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110, + 0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222, + 0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120, + 0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022, + 0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101, + 0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202, + 0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110, + 0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110, + 0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111, + 0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111, + 0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120, + 0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121, + 0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001, + 0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202, + 0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001, + 0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200, + 0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011, + 0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212, + 0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012, + 0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110, + 0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012, + 0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111, + 0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020, + 0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121, + 0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222, + 0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102, + 0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102, + 0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101, + 0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212, + 0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210, + 0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111, + 0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212, + 0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221, + 0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121, + 0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002, + 0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000, + 0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202, + 0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112, + 0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111, + 0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020, + 0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221, + 0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022, + 0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100, + 0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201, + 0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112, + 0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211, + 0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012, + 0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121, + 0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020, + 0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120, + 0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200, + 0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200, + 0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110, + 0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011, + 0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222, + 0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020, + 0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222, +}; + shared uint16_t iq1s_grid[2048]; +shared uint32_t iq1s_grid_gpu[2048]; #define NEEDS_INIT_IQ_SHMEM void init_iq_shmem(uvec3 wgsize) @@ -573,6 +857,12 @@ void init_iq_shmem(uvec3 wgsize) iq1s_grid[2*idx+1] = g.y; } } + [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) { + iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx]; + } + } barrier(); } #endif diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp index 4a83378374..bbdbf9dcaa 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c for (const auto& tname : type_names) { std::string load_vec_quant = "2"; - if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) load_vec_quant = "8"; - else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) load_vec_quant = "4"; if (tname == "bf16") { @@ -685,7 +685,7 @@ void process_shaders() { // mul mat vec with integer dot product #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) - if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) { + if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") { string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); @@ -944,6 +944,8 @@ void process_shaders() { string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}})); @@ -1123,7 +1125,7 @@ void write_output_files() { for (const std::string& btype : btypes) { for (const auto& tname : type_names) { - if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) { + if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") { continue; } hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; diff --git a/ggml/src/ggml-webgpu/ggml-webgpu.cpp b/ggml/src/ggml-webgpu/ggml-webgpu.cpp index d0e99b6fe2..c7afdfb8e9 100644 --- a/ggml/src/ggml-webgpu/ggml-webgpu.cpp +++ b/ggml/src/ggml-webgpu/ggml-webgpu.cpp @@ -2273,6 +2273,16 @@ static void ggml_webgpu_init_unary_pipeline(webgpu_context & webgpu_ctx) { ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f32, "xielu_inplace_f32", constants); webgpu_ctx->unary_pipelines[GGML_UNARY_OP_XIELU][GGML_TYPE_F16][1] = ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_xielu_inplace_f16, "xielu_inplace_f16", constants); + + // CEIL + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f32, "ceil_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][0] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_f16, "ceil_f16", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F32][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f32, "ceil_inplace_f32", constants); + webgpu_ctx->unary_pipelines[GGML_UNARY_OP_CEIL][GGML_TYPE_F16][1] = + ggml_webgpu_create_pipeline(webgpu_ctx->device, wgsl_ceil_inplace_f16, "ceil_inplace_f16", constants); } static void ggml_webgpu_init_scale_pipeline(webgpu_context & webgpu_ctx) { @@ -2528,6 +2538,7 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_XIELU: + case GGML_UNARY_OP_CEIL: supports_op = supports_op = (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && (src0->type == op->type); break; diff --git a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl index d474ab107b..25fe285451 100644 --- a/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl +++ b/ggml/src/ggml-webgpu/wgsl-shaders/unary_op.wgsl @@ -16,7 +16,8 @@ "HARDSWISH_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * min(1.0, max(0.0, (src[src_i] + 3.0) / 6.0));", "GELU_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(sqrt(2.0 / 3.14159265) * (src[src_i] + 0.044715 * pow(src[src_i], 3.0)), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", "GELU_QUICK_FUNC": "{{MUTATE}}[dst_i] = src[src_i] * 0.5 * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", - "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458" + "GELU_ERF_FUNC": "{{MUTATE}}[dst_i] = 0.5 * src[src_i] * (1.0 + tanh(clamp(0.79788456 * (src[src_i] + 0.044715 * src[src_i] * src[src_i] * src[src_i]), -9.010913, 9.010913))); // Regarding tanh() domain restrictions in wgsl https://github.com/gpuweb/gpuweb/issues/4458", + "CEIL_FUNC": "{{MUTATE}}[dst_i] = ceil(src[src_i]);" } #end(REPL_TEMPLATES) @@ -357,6 +358,27 @@ "SHADER_NAME": "gelu_erf_inplace_f16", "REPLS": { "TYPE": "f16", "FUNC": "GELU_ERF_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, "DECLS": ["INPLACE"] + }, + + { + "SHADER_NAME": "ceil_f32", + "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "ceil_f16", + "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "dst" }, + "DECLS": ["NOT_INPLACE"] + }, + { + "SHADER_NAME": "ceil_inplace_f32", + "REPLS": { "TYPE": "f32", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] + }, + { + "SHADER_NAME": "ceil_inplace_f16", + "REPLS": { "TYPE": "f16", "FUNC": "CEIL_FUNC", "EXT_PARAMS": "", "MUTATE": "src" }, + "DECLS": ["INPLACE"] } ] diff --git a/ggml/src/ggml.c b/ggml/src/ggml.c index eb3ae72eaa..09b8eb466d 100644 --- a/ggml/src/ggml.c +++ b/ggml/src/ggml.c @@ -53,13 +53,15 @@ #define UNUSED GGML_UNUSED +// Needed for ggml_fp32_to_bf16_row() +#if defined(__AVX512BF16__) #if defined(_MSC_VER) -#define m512bh(p) p #define m512i(p) p #else -#define m512bh(p) (__m512bh)(p) +#include #define m512i(p) (__m512i)(p) -#endif +#endif // defined(_MSC_VER) +#endif // defined(__AVX512BF16__) #if defined(__linux__) || \ defined(__FreeBSD__) || defined(__NetBSD__) || defined(__OpenBSD__) || \ diff --git a/gguf-py/gguf/constants.py b/gguf-py/gguf/constants.py index c2a0f41c1b..64c227799f 100644 --- a/gguf-py/gguf/constants.py +++ b/gguf-py/gguf/constants.py @@ -104,6 +104,7 @@ class Keys: VOCAB_SIZE = "{arch}.vocab_size" CONTEXT_LENGTH = "{arch}.context_length" EMBEDDING_LENGTH = "{arch}.embedding_length" + EMBEDDING_LENGTH_OUT = "{arch}.embedding_length_out" FEATURES_LENGTH = "{arch}.features_length" BLOCK_COUNT = "{arch}.block_count" LEADING_DENSE_BLOCK_COUNT = "{arch}.leading_dense_block_count" @@ -294,7 +295,9 @@ class Keys: USE_GELU = "clip.use_gelu" USE_SILU = "clip.use_silu" N_WA_PATTERN = "clip.vision.n_wa_pattern" # used by qwen2.5vl + WA_LAYER_INDEXES = "clip.vision.wa_layer_indexes" # used by youtuvl IS_DEEPSTACK_LAYERS = "clip.vision.is_deepstack_layers" + WINDOW_SIZE = "clip.vision.window_size" class Attention: HEAD_COUNT = "clip.vision.attention.head_count" @@ -452,6 +455,7 @@ class MODEL_ARCH(IntEnum): MISTRAL3 = auto() MIMO2 = auto() LLAMA_EMBED = auto() + MAINCODER = auto() class VISION_PROJECTOR_TYPE(IntEnum): @@ -850,6 +854,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = { MODEL_ARCH.MISTRAL3: "mistral3", MODEL_ARCH.MIMO2: "mimo2", MODEL_ARCH.LLAMA_EMBED: "llama-embed", + MODEL_ARCH.MAINCODER: "maincoder", } VISION_PROJECTOR_TYPE_NAMES: dict[VISION_PROJECTOR_TYPE, str] = { @@ -3034,6 +3039,7 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.ATTN_V, MODEL_TENSOR.ATTN_OUT, MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.DENSE_2_OUT, # LFM2-ColBert-350M ], MODEL_ARCH.LFM2MOE: [ MODEL_TENSOR.TOKEN_EMBD, @@ -3257,6 +3263,22 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = { MODEL_TENSOR.FFN_DOWN_EXP, MODEL_TENSOR.FFN_UP_EXP, ], + MODEL_ARCH.MAINCODER: [ + MODEL_TENSOR.TOKEN_EMBD, + MODEL_TENSOR.OUTPUT_NORM, + MODEL_TENSOR.OUTPUT, + MODEL_TENSOR.ATTN_NORM, + MODEL_TENSOR.ATTN_Q, + MODEL_TENSOR.ATTN_Q_NORM, + MODEL_TENSOR.ATTN_K, + MODEL_TENSOR.ATTN_K_NORM, + MODEL_TENSOR.ATTN_V, + MODEL_TENSOR.ATTN_OUT, + MODEL_TENSOR.FFN_NORM, + MODEL_TENSOR.FFN_GATE, + MODEL_TENSOR.FFN_DOWN, + MODEL_TENSOR.FFN_UP, + ], # TODO } @@ -3494,6 +3516,7 @@ class VisionProjectorType: LFM2A = "lfm2a" # audio MUSIC_FLAMINGO = "musicflamingo" # audio GLM4V = "glm4v" + YOUTUVL = "youtuvl" # Items here are (block size, type size) diff --git a/gguf-py/gguf/gguf_writer.py b/gguf-py/gguf/gguf_writer.py index 6a4a504f8d..a7506aa793 100644 --- a/gguf-py/gguf/gguf_writer.py +++ b/gguf-py/gguf/gguf_writer.py @@ -681,6 +681,9 @@ class GGUFWriter: def add_embedding_length(self, length: int) -> None: self.add_uint32(Keys.LLM.EMBEDDING_LENGTH.format(arch=self.arch), length) + def add_embedding_length_out(self, length: int) -> None: + self.add_uint32(Keys.LLM.EMBEDDING_LENGTH_OUT.format(arch=self.arch), length) + def add_features_length(self, length: int) -> None: self.add_uint32(Keys.LLM.FEATURES_LENGTH.format(arch=self.arch), length) @@ -1129,11 +1132,40 @@ class GGUFWriter: self.add_uint32(Keys.ClipVision.Projector.SCALE_FACTOR, value) def add_vision_n_wa_pattern(self, value: int) -> None: + """Add window attention pattern interval for vision models. + + This defines the pattern interval for window attention vs full attention layers. + For example, if n_wa_pattern=4, then layers 3, 7, 11, ... use full attention, + while other layers use window attention. + + Used by models like Qwen2.5-VL where full attention layers follow a regular pattern. + """ self.add_uint32(Keys.ClipVision.N_WA_PATTERN, value) + def add_vision_wa_layer_indexes(self, layers: Sequence[int]) -> None: + """Add explicit layer indexes that use full attention in vision models. + + This specifies the exact layer indices (0-based) that should use full attention + instead of window attention. All other layers will use window attention. + + Args: + layers: List of layer indices that use full attention (e.g., [3, 7, 11, 15]) + + Used by models like YoutuVL where full attention layers are explicitly specified + rather than following a regular pattern. + + Difference from add_vision_n_wa_pattern: + - n_wa_pattern: Defines a regular interval pattern (every Nth layer uses full attention) + - wa_layer_indexes: Explicitly lists which layers use full attention (irregular pattern) + """ + self.add_array(Keys.ClipVision.WA_LAYER_INDEXES, layers) + def add_vision_is_deepstack_layers(self, layers: Sequence[bool]) -> None: self.add_array(Keys.ClipVision.IS_DEEPSTACK_LAYERS, layers) + def add_vision_window_size(self, value: int) -> None: + self.add_uint32(Keys.ClipVision.WINDOW_SIZE, value) + # audio models def add_audio_projection_dim(self, value: int) -> None: diff --git a/gguf-py/gguf/tensor_mapping.py b/gguf-py/gguf/tensor_mapping.py index 115df6c7c3..64dd4ddca5 100644 --- a/gguf-py/gguf/tensor_mapping.py +++ b/gguf-py/gguf/tensor_mapping.py @@ -1221,6 +1221,7 @@ class TensorNameMap: MODEL_TENSOR.V_MMPROJ: ( "multi_modal_projector.linear_{bid}", "visual.merger.mlp.{bid}", # qwen2vl + "merger.mlp.{bid}", ), MODEL_TENSOR.V_MMPROJ_FC: ( @@ -1258,6 +1259,7 @@ class TensorNameMap: "visual.patch_embed.proj", # qwen2vl "vision_tower.patch_embed.proj", # kimi-vl "model.vision.patch_embedding.proj", # cogvlm + "siglip2.vision_model.embeddings.patch_embedding", ), MODEL_TENSOR.V_ENC_EMBD_NORM: ( @@ -1291,6 +1293,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wq", # pixtral "visual.blocks.{bid}.attn.q", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wq", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.q_proj", # youtuvl ), MODEL_TENSOR.V_ENC_ATTN_Q_NORM: ( @@ -1308,6 +1311,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wk", # pixtral "visual.blocks.{bid}.attn.k", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wk", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.k_proj", ), MODEL_TENSOR.V_ENC_ATTN_K_NORM: ( @@ -1325,6 +1329,7 @@ class TensorNameMap: "vision_encoder.transformer.layers.{bid}.attention.wv", # pixtral "visual.blocks.{bid}.attn.v", # qwen2vl, generated "vision_tower.encoder.blocks.{bid}.wv", # kimi-vl, generated + "siglip2.vision_model.encoder.layers.{bid}.self_attn.v_proj", ), MODEL_TENSOR.V_ENC_INPUT_NORM: ( @@ -1339,6 +1344,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm1", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm0", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.input_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm1", ), MODEL_TENSOR.V_ENC_ATTN_O: ( @@ -1354,6 +1360,7 @@ class TensorNameMap: "visual.blocks.{bid}.attn.proj", # qwen2vl "vision_tower.encoder.blocks.{bid}.wo", # kimi-vl "model.vision.transformer.layers.{bid}.attention.dense", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.self_attn.out_proj", # youtuvl ), MODEL_TENSOR.V_ENC_POST_ATTN_NORM: ( @@ -1368,6 +1375,7 @@ class TensorNameMap: "visual.blocks.{bid}.norm2", # qwen2vl "vision_tower.encoder.blocks.{bid}.norm1", # kimi-vl (norm0/norm1) "model.vision.transformer.layers.{bid}.post_attention_layernorm", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.layer_norm2", ), MODEL_TENSOR.V_ENC_FFN_UP: ( @@ -1383,6 +1391,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc1", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc0", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc1", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc1", ), MODEL_TENSOR.V_ENC_FFN_GATE: ( @@ -1404,6 +1413,7 @@ class TensorNameMap: "visual.blocks.{bid}.mlp.linear_fc2", # qwen3vl "vision_tower.encoder.blocks.{bid}.mlp.fc1", # kimi-vl (fc0/fc1) "model.vision.transformer.layers.{bid}.mlp.fc2", # cogvlm + "siglip2.vision_model.encoder.layers.{bid}.mlp.fc2", ), MODEL_TENSOR.V_LAYER_SCALE_1: ( @@ -1430,6 +1440,7 @@ class TensorNameMap: "visual.merger.ln_q", # qwen2vl "vision_tower.encoder.final_layernorm", # kimi-vl "visual.post_layernorm", # glm4v + "siglip2.vision_model.post_layernorm", ), MODEL_TENSOR.V_MM_POST_NORM: ( @@ -1446,6 +1457,7 @@ class TensorNameMap: "multi_modal_projector.pre_norm", "pre_mm_projector_norm", "model.vision.linear_proj.norm1", # cogvlm + "merger.ln_q", ), MODEL_TENSOR.V_MM_SOFT_EMB_NORM: ( diff --git a/gguf-py/pyproject.toml b/gguf-py/pyproject.toml index 0f3a1eeee8..f6c4cd14e7 100644 --- a/gguf-py/pyproject.toml +++ b/gguf-py/pyproject.toml @@ -22,6 +22,7 @@ python = ">=3.8" numpy = ">=1.17" tqdm = ">=4.27" pyyaml = ">=5.1" +requests = ">=2.25" sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true } PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true } diff --git a/include/llama.h b/include/llama.h index 8b3c8a7b10..05cb653254 100644 --- a/include/llama.h +++ b/include/llama.h @@ -316,6 +316,11 @@ extern "C" { bool no_alloc; // only load metadata and simulate memory allocations }; + struct llama_sampler_seq_config { + llama_seq_id seq_id; + struct llama_sampler * sampler; + }; + // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations // https://github.com/ggml-org/llama.cpp/pull/7544 struct llama_context_params { @@ -364,6 +369,12 @@ extern "C" { bool kv_unified; // use a unified buffer across the input sequences when computing the attention // try to disable when n_seq_max > 1 for improved performance when the sequences do not share a large prefix // ref: https://github.com/ggml-org/llama.cpp/pull/14363 + + // [EXPERIMENTAL] + // backend sampler chain configuration (make sure the caller keeps the sampler chains alive) + // note: the samplers must be sampler chains (i.e. use llama_sampler_chain_init) + struct llama_sampler_seq_config * samplers; + size_t n_samplers; }; // model quantization parameters @@ -524,6 +535,7 @@ extern "C" { LLAMA_API int32_t llama_model_n_ctx_train(const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd (const struct llama_model * model); LLAMA_API int32_t llama_model_n_embd_inp (const struct llama_model * model); + LLAMA_API int32_t llama_model_n_embd_out (const struct llama_model * model); LLAMA_API int32_t llama_model_n_layer (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head (const struct llama_model * model); LLAMA_API int32_t llama_model_n_head_kv (const struct llama_model * model); @@ -992,6 +1004,32 @@ extern "C" { // otherwise: float[n_embd] (1-dimensional) LLAMA_API float * llama_get_embeddings_seq(struct llama_context * ctx, llama_seq_id seq_id); + // + // backend sampling API [EXPERIMENTAL] + // note: use only if the llama_context was created with at least one llama_sampler_seq_config + // + + // Get the backend sampled token for the ith token. + // Returns LLAMA_TOKEN_NULL if no token was sampled. + LLAMA_API llama_token llama_get_sampled_token_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled probabilites for the ith token + // The index matches llama_get_sampled_token_ith(). + // Returns NULL if no probabilites were generated. + LLAMA_API float * llama_get_sampled_probs_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_probs_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled logits for the ith token + // Returns NULL if no logits were sampled. + LLAMA_API float * llama_get_sampled_logits_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_logits_count_ith(struct llama_context * ctx, int32_t i); + + // Get the backend sampled candidates (token ids) for the ith token + // These are needed to map probability/logit indices to vocab token ids. + // Returns NULL if no candidates were sampled. + LLAMA_API llama_token * llama_get_sampled_candidates_ith (struct llama_context * ctx, int32_t i); + LLAMA_API uint32_t llama_get_sampled_candidates_count_ith(struct llama_context * ctx, int32_t i); + // // Vocab // @@ -1163,11 +1201,16 @@ extern "C" { // // llama_sampler_free(smpl); // - // TODO: In the future, llama_sampler will be utilized to offload the sampling to the backends (e.g. GPU). - // typedef void * llama_sampler_context_t; + struct llama_sampler_data { + struct ggml_tensor * logits; + struct ggml_tensor * probs; + struct ggml_tensor * sampled; + struct ggml_tensor * candidates; + }; + // user code can implement the interface below in order to create custom llama_sampler struct llama_sampler_i { const char * (*name) (const struct llama_sampler * smpl); // can be NULL @@ -1177,17 +1220,45 @@ extern "C" { struct llama_sampler * (*clone) (const struct llama_sampler * smpl); // can be NULL if ctx is NULL void (*free) ( struct llama_sampler * smpl); // can be NULL if ctx is NULL - // TODO: API for internal libllama usage for appending the sampling to an existing ggml_cgraph - //void (*apply_ggml) (struct llama_sampler * smpl, ...); + // [EXPERIMENTAL] + // backend sampling interface: + + // return true if the backend supports all ops needed by the sampler + // note: call once per sampler + bool (*backend_init)(struct llama_sampler * smpl, ggml_backend_buffer_type_t buft); + + // call after .backend_apply() + void (*backend_accept)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct ggml_tensor * selected_token); + + // call after .backend_init() + void (*backend_apply)( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data); + + // called before graph execution to set inputs for the current ubatch + void (*backend_set_input)(struct llama_sampler * smpl); }; struct llama_sampler { - const struct llama_sampler_i * iface; - llama_sampler_context_t ctx; + struct llama_sampler_i * iface; + + llama_sampler_context_t ctx; }; + // [EXPERIMENTAL] + // attach a sampler to the context + // note: prefer initializing the context with llama_context_params.samplers when possible + // note: changing the samplers of a context can cause graph reallocations and degraded performance + LLAMA_API bool llama_set_sampler(struct llama_context * ctx, llama_seq_id seq_id, struct llama_sampler * smpl); + // mirror of llama_sampler_i: - LLAMA_API struct llama_sampler * llama_sampler_init (const struct llama_sampler_i * iface, llama_sampler_context_t ctx); + LLAMA_API struct llama_sampler * llama_sampler_init ( struct llama_sampler_i * iface, llama_sampler_context_t ctx); LLAMA_API const char * llama_sampler_name (const struct llama_sampler * smpl); LLAMA_API void llama_sampler_accept( struct llama_sampler * smpl, llama_token token); LLAMA_API void llama_sampler_apply ( struct llama_sampler * smpl, llama_token_data_array * cur_p); @@ -1203,7 +1274,15 @@ extern "C" { // important: takes ownership of the sampler object and will free it when llama_sampler_free is called LLAMA_API void llama_sampler_chain_add( struct llama_sampler * chain, struct llama_sampler * smpl); - LLAMA_API struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i); + + // return NULL if: + // - the sampler is NULL + // - the sampler is not a llama_sampler_chain + // - the index is out of bounds, unless i == -1 + // - if i == -1, returns the chain itself (can be used to check if the sampler is a chain) + LLAMA_API struct llama_sampler * llama_sampler_chain_get( struct llama_sampler * chain, int32_t i); + + // the total number of samplers in the chain LLAMA_API int llama_sampler_chain_n (const struct llama_sampler * chain); // after removing a sampler, the chain will no longer own it, and it will not be freed when the chain is freed diff --git a/licenses/LICENSE-linenoise b/licenses/LICENSE-linenoise deleted file mode 100644 index b006b3b24d..0000000000 --- a/licenses/LICENSE-linenoise +++ /dev/null @@ -1,26 +0,0 @@ -Copyright (c) 2010-2014, Salvatore Sanfilippo -Copyright (c) 2010-2013, Pieter Noordhuis -Copyright (c) 2025, Eric Curtin - -All rights reserved. - -Redistribution and use in source and binary forms, with or without -modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, - this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, - this list of conditions and the following disclaimer in the documentation - and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND -ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED -WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE -DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR -ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES -(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; -LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON -ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT -(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS -SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/scripts/pr2wt.sh b/scripts/pr2wt.sh new file mode 100755 index 0000000000..22251339ac --- /dev/null +++ b/scripts/pr2wt.sh @@ -0,0 +1,65 @@ +#!/usr/bin/env bash + +# intialize a new worktree from a PR number: +# +# - creates a new remote using the fork's clone URL +# - creates a local branch tracking the remote branch +# - creates a new worktree in a parent folder, suffixed with "-pr-${PR}" +# +# sample usage: +# ./scripts/pr2wt.sh 12345 +# ./scripts/pr2wt.sh 12345 opencode + +function usage() { + echo "usage: $0 [cmd]" + exit 1 +} + +# check we are in the right directory +if [[ ! -f "scripts/pr2wt.sh" ]]; then + echo "error: this script must be run from the root of the repository" + exit 1 +fi + +if [[ $# -lt 1 || $# -gt 2 ]]; then + usage +fi + +PR=$1 +[[ "$PR" =~ ^[0-9]+$ ]] || { echo "error: PR number must be numeric"; exit 1; } + +url_origin=$(git config --get remote.origin.url) || { + echo "error: no remote named 'origin' in this repository" + exit 1 +} + +org_repo=$(echo $url_origin | cut -d/ -f4-) + +echo "org/repo: $org_repo" + +meta=$(curl -sSf -H "Accept: application/vnd.github+json" "https://api.github.com/repos/${org_repo}/pulls/${PR}") + +url_remote=$(echo "$meta" | jq -r '.head.repo.clone_url') +head_ref=$(echo "$meta" | jq -r '.head.ref') + +echo "url: $url_remote" +echo "head_ref: $head_ref" + +git remote rm pr/${PR} +git remote add pr/${PR} $url_remote +git fetch pr/${PR} $head_ref + +dir=$(basename $(pwd)) + +git branch -D pr/$PR 2> /dev/null +git worktree add -b pr/$PR ../$dir-pr-$PR pr/$PR/${head_ref} 2> /dev/null + +wt_path=$(cd ../$dir-pr-$PR && pwd) + +echo "git worktree created in $wt_path" + +# if a command was provided, execute it +if [[ $# -eq 2 ]]; then + cd ../$dir-pr-$PR + exec $2 +fi diff --git a/scripts/snapdragon/adb/run-bench.sh b/scripts/snapdragon/adb/run-bench.sh index b2e651e749..1a7d8c9fd6 100755 --- a/scripts/snapdragon/adb/run-bench.sh +++ b/scripts/snapdragon/adb/run-bench.sh @@ -16,8 +16,14 @@ model="Llama-3.2-3B-Instruct-Q4_0.gguf" device="HTP0" [ "$D" != "" ] && device="$D" -verbose="" -[ "$V" != "" ] && verbose="$V" +verbose= +[ "$V" != "" ] && verbose="GGML_HEXAGON_VERBOSE=$V" cli_opts="$cli_opts -v" + +experimental= +[ "$E" != "" ] && experimental="GGML_HEXAGON_EXPERIMENTAL=$E" + +profile= +[ "$PROF" != "" ] && profile="GGML_HEXAGON_PROFILE=$PROF GGML_HEXAGON_OPSYNC=1" cli_opts="$cli_opts -v" opmask= [ "$OPMASK" != "" ] && opmask="GGML_HEXAGON_OPMASK=$OPMASK" @@ -34,7 +40,7 @@ adb $adbserial shell " \ cd $basedir; \ LD_LIBRARY_PATH=$basedir/$branch/lib \ ADSP_LIBRARY_PATH=$basedir/$branch/lib \ - $ndev $nhvx $opmask ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ + $ndev $nhvx $opmask $verbose $experimental $profile ./$branch/bin/llama-bench --device $device --mmap 0 -m $basedir/../gguf/$model \ --poll 1000 -t 6 --cpu-mask 0xfc --cpu-strict 1 \ - --batch-size 128 -ngl 99 $@ \ + --batch-size 128 -ngl 99 $cli_opts $@ \ " diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 762ea65c71..b0932794d4 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -87,6 +87,7 @@ add_library(llama models/llada.cpp models/llama-iswa.cpp models/llama.cpp + models/maincoder.cpp models/mamba.cpp models/mimo2-iswa.cpp models/minicpm3.cpp diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp index 94a6807eac..2ead965469 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -118,6 +118,7 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_MISTRAL3, "mistral3" }, { LLM_ARCH_MIMO2, "mimo2" }, { LLM_ARCH_LLAMA_EMBED, "llama-embed" }, + { LLM_ARCH_MAINCODER, "maincoder" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -151,6 +152,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_VOCAB_SIZE, "%s.vocab_size" }, { LLM_KV_CONTEXT_LENGTH, "%s.context_length" }, { LLM_KV_EMBEDDING_LENGTH, "%s.embedding_length" }, + { LLM_KV_EMBEDDING_LENGTH_OUT, "%s.embedding_length_out" }, { LLM_KV_FEATURES_LENGTH, "%s.features_length" }, { LLM_KV_BLOCK_COUNT, "%s.block_count" }, { LLM_KV_LEADING_DENSE_BLOCK_COUNT, "%s.leading_dense_block_count" }, @@ -2074,6 +2076,7 @@ static std::set llm_get_tensor_names(llm_arch arch) { LLM_TENSOR_TOKEN_EMBD, LLM_TENSOR_OUTPUT_NORM_LFM2, LLM_TENSOR_OUTPUT, + LLM_TENSOR_DENSE_2_OUT, }; case LLM_ARCH_LFM2MOE: return { @@ -2234,6 +2237,23 @@ static std::set llm_get_tensor_names(llm_arch arch) { return { LLM_TENSOR_TOKEN_EMBD, }; + case LLM_ARCH_MAINCODER: + return { + LLM_TENSOR_TOKEN_EMBD, + LLM_TENSOR_OUTPUT_NORM, + LLM_TENSOR_OUTPUT, + LLM_TENSOR_ATTN_NORM, + LLM_TENSOR_ATTN_Q, + LLM_TENSOR_ATTN_Q_NORM, + LLM_TENSOR_ATTN_K, + LLM_TENSOR_ATTN_K_NORM, + LLM_TENSOR_ATTN_V, + LLM_TENSOR_ATTN_OUT, + LLM_TENSOR_FFN_NORM, + LLM_TENSOR_FFN_GATE, + LLM_TENSOR_FFN_DOWN, + LLM_TENSOR_FFN_UP, + }; default: GGML_ABORT("unknown architecture for tensor mapping"); } diff --git a/src/llama-arch.h b/src/llama-arch.h index 714ead4025..68ec6a18b1 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -122,6 +122,7 @@ enum llm_arch { LLM_ARCH_MISTRAL3, LLM_ARCH_MIMO2, LLM_ARCH_LLAMA_EMBED, + LLM_ARCH_MAINCODER, LLM_ARCH_UNKNOWN, }; @@ -155,6 +156,7 @@ enum llm_kv { LLM_KV_VOCAB_SIZE, LLM_KV_CONTEXT_LENGTH, LLM_KV_EMBEDDING_LENGTH, + LLM_KV_EMBEDDING_LENGTH_OUT, LLM_KV_FEATURES_LENGTH, LLM_KV_BLOCK_COUNT, LLM_KV_LEADING_DENSE_BLOCK_COUNT, diff --git a/src/llama-chat.cpp b/src/llama-chat.cpp index fc6a6223cf..b54ebbd155 100644 --- a/src/llama-chat.cpp +++ b/src/llama-chat.cpp @@ -74,6 +74,7 @@ static const std::map LLM_CHAT_TEMPLATES = { { "seed_oss", LLM_CHAT_TEMPLATE_SEED_OSS }, { "grok-2", LLM_CHAT_TEMPLATE_GROK_2 }, { "pangu-embedded", LLM_CHAT_TEMPLATE_PANGU_EMBED }, + { "solar-open", LLM_CHAT_TEMPLATE_SOLAR_OPEN }, }; llm_chat_template llm_chat_template_from_str(const std::string & name) { @@ -216,6 +217,8 @@ llm_chat_template llm_chat_detect_template(const std::string & tmpl) { return LLM_CHAT_TEMPLATE_GROK_2; } else if (tmpl_contains(LU8("[unused9]系统:[unused10]"))) { return LLM_CHAT_TEMPLATE_PANGU_EMBED; + } else if (tmpl_contains("<|begin|>") && tmpl_contains("<|end|>") && tmpl_contains("<|content|>")) { + return LLM_CHAT_TEMPLATE_SOLAR_OPEN; } return LLM_CHAT_TEMPLATE_UNKNOWN; } @@ -845,6 +848,14 @@ int32_t llm_chat_apply_template( if (add_ass) { ss << "[unused9]助手:"; } + } else if (tmpl == LLM_CHAT_TEMPLATE_SOLAR_OPEN) { + for (auto message : chat) { + std::string role(message->role); + ss << "<|begin|>" << role << "<|content|>" << message->content << "<|end|>"; + } + if (add_ass) { + ss << "<|begin|>assistant"; + } } else { // template not supported return -1; diff --git a/src/llama-chat.h b/src/llama-chat.h index 684efb4d67..e1f795249c 100644 --- a/src/llama-chat.h +++ b/src/llama-chat.h @@ -54,6 +54,7 @@ enum llm_chat_template { LLM_CHAT_TEMPLATE_SEED_OSS, LLM_CHAT_TEMPLATE_GROK_2, LLM_CHAT_TEMPLATE_PANGU_EMBED, + LLM_CHAT_TEMPLATE_SOLAR_OPEN, LLM_CHAT_TEMPLATE_UNKNOWN, }; diff --git a/src/llama-context.cpp b/src/llama-context.cpp index 34dfcd4724..f220010a1b 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -60,6 +60,25 @@ llama_context::llama_context( cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; + // Initialize backend samplers here so they are part of the sampling graph + // before the reserve passes run later in this function. This avoids a later + // re-reserve when graph nodes change. + if (params.samplers != nullptr && params.n_samplers > 0) { + for (size_t i = 0; i < params.n_samplers; ++i) { + const auto & config = params.samplers[i]; + + if (llama_sampler_chain_get(config.sampler, -1) == nullptr) { + throw std::runtime_error("the backend samplers must be of type llama_sampler_chain"); + } + + if (set_sampler(config.seq_id, config.sampler)) { + const int n_samplers = llama_sampler_chain_n(config.sampler); + + LLAMA_LOG_INFO("%s: setting backend sampler for seq_id %d (n = %d)\n", __func__, config.seq_id, n_samplers); + } + } + } + auto rope_scaling_type = params.rope_scaling_type; if (rope_scaling_type == LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED) { rope_scaling_type = hparams.rope_scaling_type_train; @@ -231,7 +250,10 @@ llama_context::llama_context( // graph outputs buffer { // resized during inference when a batch uses more outputs - if (output_reserve(params.n_seq_max) < params.n_seq_max) { + // Create a dummy batch for initialization. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (output_reserve(params.n_seq_max, dummy_batch) < params.n_seq_max) { throw std::runtime_error("failed to reserve initial output buffer"); } @@ -456,6 +478,16 @@ llama_context::llama_context( LLAMA_LOG_INFO("%s: graph splits = %d (with bs=%d), %d (with bs=1)\n", __func__, n_splits_pp, n_tokens, n_splits_tg); } } + + // Initialize the full vocabulary token ids for backend samplers. + { + const int n_vocab = model.vocab.n_tokens(); + + sampling.token_ids_full_vocab.resize(n_vocab); + for (int i = 0; i < n_vocab; ++i) { + sampling.token_ids_full_vocab[i] = i; + } + } } llama_context::~llama_context() { @@ -616,6 +648,35 @@ float * llama_context::get_logits() { return logits; } +int64_t llama_context::output_resolve_row(int32_t i) const { + int64_t j = -1; + + // support negative indices (last output row) + if (i < 0) { + j = n_outputs + i; + if (j < 0) { + throw std::runtime_error(format("negative index out of range [0, %d)", n_outputs)); + } + } else if ((size_t) i >= output_ids.size()) { + throw std::runtime_error(format("out of range [0, %zu)", output_ids.size())); + } else { + // use output_ids to translate the batch token index into a row number + // that holds this token's data. + j = output_ids[i]; + } + + if (j < 0) { + // the batch token was not configured to output anything + throw std::runtime_error(format("batch.logits[%d] != true", i)); + } + + if (j >= n_outputs) { + throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); + } + + return j; +} + float * llama_context::get_logits_ith(int32_t i) { int64_t j = -1; @@ -626,6 +687,7 @@ float * llama_context::get_logits_ith(int32_t i) { throw std::runtime_error("no logits"); } + // TODO: use output_resolve_row() if (i < 0) { j = n_outputs + i; if (j < 0) { @@ -662,6 +724,10 @@ float * llama_context::get_embeddings() { return embd; } +llama_token * llama_context::get_sampled_tokens() const{ + return sampling.sampled; +} + float * llama_context::get_embeddings_ith(int32_t i) { int64_t j = -1; @@ -672,6 +738,7 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error("no embeddings"); } + // TODO: use output_resolve_row() if (i < 0) { j = n_outputs + i; if (j < 0) { @@ -691,7 +758,8 @@ float * llama_context::get_embeddings_ith(int32_t i) { throw std::runtime_error(format("corrupt output buffer (j=%" PRId64 ", n_outputs=%d)", j, n_outputs)); } - return embd + j*model.hparams.n_embd; + const uint32_t n_embd_out = model.hparams.get_n_embd_out(); + return embd + j*n_embd_out; } catch (const std::exception & err) { LLAMA_LOG_ERROR("%s: invalid embeddings id %d, reason: %s\n", __func__, i, err.what()); #ifndef NDEBUG @@ -711,6 +779,136 @@ float * llama_context::get_embeddings_seq(llama_seq_id seq_id) { return it->second.data(); } +llama_token llama_context::get_sampled_token_ith(int32_t idx) { + output_reorder(); + + if (sampling.sampled == nullptr) { + return LLAMA_TOKEN_NULL; + } + + try { + const int64_t row = output_resolve_row(idx); + GGML_ASSERT(row < (int64_t) sampling.sampled_size); + return sampling.sampled[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled token id %d, reason: %s\n", __func__, idx, err.what()); + return LLAMA_TOKEN_NULL; + } +} + +float * llama_context::get_sampled_probs_ith(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return nullptr; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size() || sampling.probs_count[row] == 0) { + return nullptr; + } + return sampling.probs + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +float * llama_context::get_sampled_logits_ith(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return nullptr; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size() || sampling.logits_count[row] == 0) { + return nullptr; + } + return sampling.logits + row*model.vocab.n_tokens(); + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits id %d, reason: %s\n", __func__, idx, err.what()); + return nullptr; + } +} + +const llama_token * llama_context::get_sampled_candidates_ith(int32_t idx) { + output_reorder(); + + try { + const int64_t row = output_resolve_row(idx); + if (sampling.candidates != nullptr && + (size_t) row < sampling.candidates_count.size() && + sampling.candidates_count[row] > 0) { + return sampling.candidates + row*model.vocab.n_tokens(); + } + } catch (const std::exception & err) { + // fallback to full vocab list + } + + return sampling.token_ids_full_vocab.data(); +} + +size_t llama_context::get_sampled_candidates_count(int32_t idx) { + output_reorder(); + + if (sampling.candidates == nullptr) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.candidates_count.size()) { + return 0; + } + return sampling.candidates_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled candidates count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_logits_count(int32_t idx) { + output_reorder(); + + if (sampling.logits == nullptr) { + return model.vocab.n_tokens(); + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.logits_count.size()) { + return 0; + } + return sampling.logits_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled logits count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + +size_t llama_context::get_sampled_probs_count(int32_t idx) { + output_reorder(); + + if (sampling.probs == nullptr) { + return 0; + } + + try { + const int64_t row = output_resolve_row(idx); + if ((size_t) row >= sampling.probs_count.size()) { + return 0; + } + return sampling.probs_count[row]; + } catch (const std::exception & err) { + LLAMA_LOG_ERROR("%s: invalid backend sampled probs count id %d, reason: %s\n", __func__, idx, err.what()); + return 0; + } +} + + void llama_context::attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch) { @@ -767,6 +965,42 @@ void llama_context::set_warmup(bool value) { cparams.warmup = value; } +bool llama_context::set_sampler(llama_seq_id seq_id, llama_sampler * sampler) { + LLAMA_LOG_DEBUG("%s: seq_id = %d, sampler = %p\n", __func__, (int) seq_id, (void *) sampler); + + const bool can_offload = + sampler && + sampler->iface->backend_init && + sampler->iface->backend_apply && + llama_sampler_chain_n(sampler) > 0; + + if (sampler && can_offload) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_buffer_type(model.dev_output()); + auto * host_buft = ggml_backend_dev_host_buffer_type(model.dev_output()); + if (host_buft) { + buft = host_buft; + } + + sampler->iface->backend_init(sampler, buft); + + sampling.samplers[seq_id] = sampler; + + return true; + } + + if (sampler && !can_offload) { + LLAMA_LOG_WARN("%s: sampler '%s' for seq_id = %d, cannot be offloaded to the backend\n", __func__, llama_sampler_name(sampler), seq_id); + + sampling.samplers.erase(seq_id); + + return false; + } + + sampling.samplers.erase(seq_id); + + return true; +} + void llama_context::set_adapter_lora( llama_adapter_lora * adapter, float scale) { @@ -907,7 +1141,7 @@ int llama_context::encode(const llama_batch & batch_inp) { n_queued_tokens += n_tokens; // reserve output buffer - if (output_reserve(n_tokens) < n_tokens) { + if (output_reserve(n_tokens, batch_inp) < n_tokens) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %u outputs\n", __func__, n_tokens); return -2; }; @@ -961,9 +1195,10 @@ int llama_context::encode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); + const uint32_t n_embd_out = hparams.get_n_embd_out(); - GGML_ASSERT(n_tokens*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd*sizeof(float)); + GGML_ASSERT(n_tokens*n_embd_out <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd, 0, n_tokens*n_embd_out*sizeof(float)); } break; case LLAMA_POOLING_TYPE_MEAN: case LLAMA_POOLING_TYPE_CLS: @@ -1031,6 +1266,112 @@ int llama_context::encode(const llama_batch & batch_inp) { return 0; } +static std::map build_seq_to_output_row(const llama_ubatch & ubatch, uint32_t row_offset) { + std::map seq_to_row; + // how many output tokens we have seen so far for this ubatch. + uint32_t local = 0; + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + // skip tokens that are not output. + if (!ubatch.output[i]) { + continue; + } + + const llama_seq_id seq_id = ubatch.seq_id[i][0]; + // row_offset is the number of output tokens before this ubatch. + seq_to_row[seq_id] = row_offset + local; + ++local; + } + return seq_to_row; +} + +static void copy_tensor_async_ints( + const std::map & tensor_map, + llama_token * sampled, + size_t sampled_size, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (sampled == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < sampled_size); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "sampled tokens tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + ggml_backend_tensor_get_async(backend, tensor, sampled + row, 0, sizeof(sampled[row])); + } +} + +static void copy_tensor_async_floats( + const std::map & tensor_map, + float * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "logits/probs tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + float * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of logits/probabilities that were written for this row. + counts[row] = ggml_nelements(tensor); + } +} + +static void copy_tensor_async_candidates( + const std::map & tensor_map, + llama_token * dst, + size_t stride, + std::vector & counts, + const std::map & seq_to_row, + ggml_backend_sched_t sched) { + if (dst == nullptr) { + return; + } + + for (const auto & [seq_id, tensor] : tensor_map) { + auto it = seq_to_row.find(seq_id); + if (it == seq_to_row.end()) { + continue; + } + + const uint32_t row = it->second; + GGML_ASSERT(row < counts.size()); + + GGML_ASSERT(ggml_is_contiguous(tensor) && "candidates tensor must be contiguous for async copy"); + + ggml_backend_t backend = ggml_backend_sched_get_tensor_backend(sched, tensor); + llama_token * row_ptr = dst + (size_t) row * stride; + ggml_backend_tensor_get_async(backend, tensor, row_ptr, 0, ggml_nbytes(tensor)); + + // Update the actual number of candidates that were written. + counts[row] = ggml_nelements(tensor); + } +} + int llama_context::decode(const llama_batch & batch_inp) { GGML_ASSERT((!batch_inp.token && batch_inp.embd) || (batch_inp.token && !batch_inp.embd)); // NOLINT @@ -1051,9 +1392,36 @@ int llama_context::decode(const llama_batch & batch_inp) { const int64_t n_embd = hparams.n_embd_inp(); // when computing embeddings, all tokens are output - const bool output_all = cparams.embeddings; + const bool output_all = cparams.embeddings; + const bool has_samplers = !sampling.samplers.empty(); - if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max, output_all)) { + const uint32_t n_seq_max = cparams.kv_unified ? LLAMA_MAX_SEQ : cparams.n_seq_max; + + // TODO: avoid this workaround in the future + if (has_samplers && batch_inp.logits) { + std::vector seq_output_count(n_seq_max, 0); + + for (int32_t i = 0; i < batch_inp.n_tokens; ++i) { + if (batch_inp.logits[i] == 0) { + continue; + } + + const int ns = batch_inp.n_seq_id ? batch_inp.n_seq_id[i] : 1; + + for (int32_t s = 0; s < ns; ++s) { + const llama_seq_id seq_id = batch_inp.seq_id ? batch_inp.seq_id[i][s] : 0; + + seq_output_count[seq_id]++; + if (seq_output_count[seq_id] > 1) { + LLAMA_LOG_ERROR("%s: backend sampling requires at most one output token per sequence (seq_id %d had %d)\n", + __func__, seq_id, seq_output_count[seq_id]); + return -1; + } + } + } + } + + if (!balloc->init(batch_inp, vocab, memory.get(), n_embd, n_seq_max, output_all)) { LLAMA_LOG_ERROR("%s: failed to initialize batch\n", __func__); return -1; } @@ -1134,7 +1502,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); return -2; }; @@ -1207,7 +1575,10 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract logits - if (t_logits && n_outputs > 0) { + // For multi-sequence batches that mix backend samplers and CPU sampler + // this is currently inefficient as we copy all logits even for the + // backend sampled tokens. + if (logits && t_logits && n_outputs > 0) { ggml_backend_t backend_res = ggml_backend_sched_get_tensor_backend(sched.get(), t_logits); GGML_ASSERT(backend_res != nullptr); GGML_ASSERT(logits != nullptr); @@ -1222,7 +1593,7 @@ int llama_context::decode(const llama_batch & batch_inp) { } // extract embeddings - if (t_embd && n_outputs > 0) { + if (embd && t_embd && n_outputs > 0) { ggml_backend_t backend_embd = ggml_backend_sched_get_tensor_backend(sched.get(), t_embd); GGML_ASSERT(backend_embd != nullptr); @@ -1231,12 +1602,13 @@ int llama_context::decode(const llama_batch & batch_inp) { { // extract token embeddings GGML_ASSERT(embd != nullptr); - float * embd_out = embd + n_outputs_prev*n_embd; + const uint32_t n_embd_out = hparams.get_n_embd_out(); + float * embd_out = embd + n_outputs_prev*n_embd_out; if (n_outputs) { GGML_ASSERT( n_outputs_prev + n_outputs <= n_outputs_all); - GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd <= (int64_t) embd_size); - ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd*sizeof(float)); + GGML_ASSERT((n_outputs_prev + n_outputs)*n_embd_out <= (int64_t) embd_size); + ggml_backend_tensor_get_async(backend_embd, t_embd, embd_out, 0, n_outputs*n_embd_out*sizeof(float)); } } break; case LLAMA_POOLING_TYPE_MEAN: @@ -1276,6 +1648,22 @@ int llama_context::decode(const llama_batch & batch_inp) { } } + // This flag indicates whether a backend sampler has actually sampled a specific + // token, or if it has produced probabilites. If true, we can skip the normal copying of logits and embeddings. + const bool has_sampled = !res->t_sampled.empty() || !res->t_sampled_probs.empty() || !res->t_sampled_logits.empty(); + + if (has_samplers && has_sampled) { + const auto seq_to_output_row = build_seq_to_output_row(ubatch, n_outputs_prev); + const auto stride = n_vocab; + + // async copy the sampling data from the backend to the host + copy_tensor_async_ints(res->t_sampled, sampling.sampled, sampling.sampled_size, seq_to_output_row, sched.get()); + + copy_tensor_async_floats (res->t_sampled_logits, sampling.logits, stride, sampling.logits_count, seq_to_output_row, sched.get()); + copy_tensor_async_floats (res->t_sampled_probs, sampling.probs, stride, sampling.probs_count, seq_to_output_row, sched.get()); + copy_tensor_async_candidates(res->t_candidates, sampling.candidates, stride, sampling.candidates_count, seq_to_output_row, sched.get()); + } + n_outputs_prev += n_outputs; } while (mctx->next()); @@ -1339,15 +1727,15 @@ int llama_context::decode(const llama_batch & batch_inp) { // output // -uint32_t llama_context::output_reserve(int32_t n_outputs) { +uint32_t llama_context::output_reserve(int32_t n_outputs, const llama_batch & batch) { const auto & hparams = model.hparams; const auto & vocab = model.vocab; const int64_t n_outputs_max = std::max(n_outputs, n_seq_max()); - const auto n_batch = cparams.n_batch; - const auto n_vocab = vocab.n_tokens(); - const auto n_embd = hparams.n_embd; + const auto n_batch = cparams.n_batch; + const auto n_vocab = vocab.n_tokens(); + const auto n_embd_out = hparams.get_n_embd_out(); bool has_logits = true; bool has_embd = cparams.embeddings; @@ -1358,8 +1746,53 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { has_embd = true; } - logits_size = has_logits ? n_vocab*n_outputs_max : 0; - embd_size = has_embd ? n_embd*n_outputs_max : 0; + // Check which sampling modes are needed for the current batch. + // TODO: avoid this branching by working with the worst-case + bool has_sampling = false; + bool cpu_logits = false; + + if (batch.logits) { + for (int32_t i = 0; i < batch.n_tokens; i++) { + if (!batch.logits[i]) { + continue; + } + for (int32_t j = 0; j < batch.n_seq_id[i]; j++) { + llama_seq_id seq_id = batch.seq_id[i][j]; + if (sampling.samplers.find(seq_id) != sampling.samplers.end()) { + has_sampling = true; + } else { + cpu_logits = true; + } + } + } + } else { + // When batch.logits is nullptr (when loading state with a dummy batch), + // allocate CPU logits. + cpu_logits = true; + } + + size_t backend_float_count = 0; + size_t backend_token_count = 0; + + // Allocate CPU logits buffer only if needed by sequences in this batch + logits_size = (has_logits && cpu_logits) ? n_vocab*n_outputs_max : 0; + embd_size = has_embd ? n_embd_out*n_outputs_max : 0; + + // TODO: avoid this branching by working with the worst-case + if (!has_sampling) { + sampling.logits_size = 0; + sampling.probs_size = 0; + sampling.sampled_size = 0; + sampling.candidates_size = 0; + } else { + sampling.logits_size = n_vocab*n_outputs_max; + sampling.probs_size = n_vocab*n_outputs_max; + sampling.sampled_size = n_outputs_max; + sampling.candidates_size = n_vocab*n_outputs_max; + + backend_float_count = sampling.logits_size + sampling.probs_size; + backend_token_count = sampling.sampled_size + sampling.candidates_size; + } if (output_ids.empty()) { // init, never resized afterwards @@ -1367,7 +1800,9 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { } const size_t prev_size = buf_output ? ggml_backend_buffer_get_size(buf_output.get()) : 0; - const size_t new_size = (logits_size + embd_size) * sizeof(float); + const size_t new_size = + (logits_size + embd_size + backend_float_count) * sizeof(float) + + ( backend_token_count) * sizeof(llama_token); // alloc only when more than the current capacity is required // TODO: also consider shrinking the buffer @@ -1375,9 +1810,11 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { if (buf_output) { #ifndef NDEBUG // This doesn't happen often, but may be annoying in some cases (like the HellaSwag benchmark) - LLAMA_LOG_INFO("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); + LLAMA_LOG_DEBUG("%s: reallocating output buffer from size %.02f MiB to %.02f MiB\n", __func__, prev_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif synchronize(); + + // TODO: not needed? buf_output = nullptr; logits = nullptr; embd = nullptr; @@ -1399,8 +1836,49 @@ uint32_t llama_context::output_reserve(int32_t n_outputs) { float * output_base = (float *) ggml_backend_buffer_get_base(buf_output.get()); - logits = has_logits ? output_base : nullptr; - embd = has_embd ? output_base + logits_size : nullptr; + logits = nullptr; + embd = nullptr; + + size_t offset = 0; + uint8_t * base = (uint8_t *) output_base; + + logits = (has_logits && cpu_logits) ? output_base : nullptr; + offset += logits_size * sizeof(float); + + embd = has_embd ? (float *) (base + offset) : nullptr; + offset += embd_size * sizeof(float); + + sampling.logits = nullptr; + sampling.probs = nullptr; + sampling.sampled = nullptr; + sampling.candidates = nullptr; + + if (has_sampling) { + sampling.logits = (float *) (base + offset); + offset += sampling.logits_size * sizeof(float); + + sampling.probs = (float *) (base + offset); + offset += sampling.probs_size * sizeof(float); + + sampling.sampled = (llama_token *) (base + offset); + offset += sampling.sampled_size * sizeof(llama_token); + + sampling.candidates = (llama_token *) (base + offset); + offset += sampling.candidates_size * sizeof(llama_token); + + // The count vectors keep track of the actual number of logits/probs/candidates + // copied from the backend for each output row. + + sampling.logits_count.resize(n_outputs_max); + sampling.probs_count.resize(n_outputs_max); + sampling.candidates_count.resize(n_outputs_max); + + std::fill(sampling.logits_count.begin(), sampling.logits_count.end(), 0); + std::fill(sampling.probs_count.begin(), sampling.probs_count.end(), 0); + std::fill(sampling.candidates_count.begin(), sampling.candidates_count.end(), 0); + + std::fill_n(sampling.sampled, sampling.sampled_size, LLAMA_TOKEN_NULL); + } // set all ids as invalid (negative) std::fill(output_ids.begin(), output_ids.end(), -1); @@ -1429,6 +1907,40 @@ void llama_context::output_reorder() { std::swap(embd[i0*n_embd + k], embd[i1*n_embd + k]); } } + + if (sampling.logits && sampling.logits_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.logits[i0*n_vocab + k], sampling.logits[i1*n_vocab + k]); + } + } + + if (sampling.probs && sampling.probs_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.probs[i0*n_vocab + k], sampling.probs[i1*n_vocab + k]); + } + } + + if (sampling.candidates && sampling.candidates_size > 0) { + for (uint64_t k = 0; k < n_vocab; ++k) { + std::swap(sampling.candidates[i0*n_vocab + k], sampling.candidates[i1*n_vocab + k]); + } + } + + if (sampling.sampled && sampling.sampled_size > 0) { + std::swap(sampling.sampled[i0], sampling.sampled[i1]); + } + + if (!sampling.logits_count.empty()) { + std::swap(sampling.logits_count[i0], sampling.logits_count[i1]); + } + + if (!sampling.probs_count.empty()) { + std::swap(sampling.probs_count[i0], sampling.probs_count[i1]); + } + + if (!sampling.candidates_count.empty()) { + std::swap(sampling.candidates_count[i0], sampling.candidates_count[i1]); + } } output_swaps.clear(); @@ -1458,7 +1970,7 @@ ggml_cgraph * llama_context::graph_reserve( if (n_tokens % n_seqs != 0) { n_tokens = ((n_tokens + (n_seqs - 1)) / n_seqs) * n_seqs; // round to next multiple of n_seqs - n_outputs = std::min(n_outputs, n_tokens); + n_outputs = std::max(n_outputs, n_tokens); LLAMA_LOG_DEBUG("%s: making n_tokens a multiple of n_seqs - n_tokens = %u, n_seqs = %u, n_outputs = %u\n", __func__, n_tokens, n_seqs, n_outputs); } @@ -1477,6 +1989,15 @@ ggml_cgraph * llama_context::graph_reserve( llama_batch_allocr balloc(model.hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(n_tokens/n_seqs, n_seqs); + // set one output token per sequence in order to activate all backend samplers + std::vector seq_ids(n_seqs); + for (uint32_t i = 0; i < n_seqs; ++i) { + seq_ids[i] = i; + ubatch.n_seq_id[i] = 1; + ubatch.seq_id[i] = &seq_ids[i]; + ubatch.output[i] = true; + } + auto * res = gf_res_reserve.get(); const auto gparams = graph_params(res, ubatch, mctx, LLM_GRAPH_TYPE_DEFAULT); @@ -1507,7 +2028,7 @@ llm_graph_params llama_context::graph_params( llm_graph_result * res, const llama_ubatch & ubatch, const llama_memory_context_i * mctx, - llm_graph_type gtype) const { + llm_graph_type gtype) const { return { /*.arch =*/ model.arch, /*.hparams =*/ model.hparams, @@ -1520,6 +2041,7 @@ llm_graph_params llama_context::graph_params( /*.loras =*/ &loras, /*.mctx =*/ mctx, /*.cross =*/ &cross, + /*.samplers =*/ sampling.samplers, /*.n_outputs =*/ n_outputs, /*.cb =*/ graph_get_cb(), /*.res =*/ res, @@ -1975,6 +2497,9 @@ size_t llama_context::state_write_data(llama_io_write_i & io) { } } + // TODO: handle sampling buffers and samplers state ? + // https://github.com/ggml-org/llama.cpp/pull/17004 + if (memory != nullptr) { LLAMA_LOG_DEBUG("%s: - writing memory module\n", __func__); memory->state_write(io); @@ -2007,7 +2532,10 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { auto n_outputs = this->n_outputs; io.read_to(&n_outputs, sizeof(n_outputs)); - if (n_outputs > output_reserve(n_outputs)) { + // Create a dummy batch for state loading. + llama_batch dummy_batch = {}; + dummy_batch.n_tokens = 0; + if (n_outputs > output_reserve(n_outputs, dummy_batch)) { throw std::runtime_error("could not reserve outputs"); } @@ -2061,6 +2589,9 @@ size_t llama_context::state_read_data(llama_io_read_i & io) { } } + // TODO: handle sampling buffers and samplers state ? + // https://github.com/ggml-org/llama.cpp/pull/17004 + if (memory) { LLAMA_LOG_DEBUG("%s: - reading memory module\n", __func__); @@ -2249,7 +2780,7 @@ void llama_context::opt_epoch_iter( } // reserve output buffer - if (output_reserve(n_outputs_all) < n_outputs_all) { + if (output_reserve(n_outputs_all, balloc->get_batch()) < n_outputs_all) { LLAMA_LOG_ERROR("%s: could not reserve space for batch with %d outputs\n", __func__, n_outputs_all); GGML_ABORT("TODO: handle this error"); }; @@ -2394,6 +2925,8 @@ llama_context_params llama_context_default_params() { /*.op_offload =*/ true, /*.swa_full =*/ true, /*.kv_unified =*/ false, + /*.sampler =*/ nullptr, + /*.n_sampler =*/ 0, }; return result; @@ -2553,7 +3086,15 @@ float * llama_get_logits(llama_context * ctx) { float * llama_get_logits_ith(llama_context * ctx, int32_t i) { ctx->synchronize(); - return ctx->get_logits_ith(i); + float * res = nullptr; + + res = ctx->get_sampled_logits_ith(i); + + if (!res) { + res = ctx->get_logits_ith(i); + } + + return res; } float * llama_get_embeddings(llama_context * ctx) { @@ -2574,6 +3115,52 @@ float * llama_get_embeddings_seq(llama_context * ctx, llama_seq_id seq_id) { return ctx->get_embeddings_seq(seq_id); } +bool llama_set_sampler(llama_context * ctx, llama_seq_id seq_id, llama_sampler * smpl) { + return ctx->set_sampler(seq_id, smpl); +} + +llama_token llama_get_sampled_token_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_token_ith(i); +} + +float * llama_get_sampled_probs_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_probs_ith(i); +} + +float * llama_get_sampled_logits_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return ctx->get_sampled_logits_ith(i); +} + +llama_token * llama_get_sampled_candidates_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return const_cast(ctx->get_sampled_candidates_ith(i)); +} + +uint32_t llama_get_sampled_candidates_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_candidates_count(i)); +} + +uint32_t llama_get_sampled_logits_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_logits_count(i)); +} + +uint32_t llama_get_sampled_probs_count_ith(llama_context * ctx, int32_t i) { + ctx->synchronize(); + + return static_cast(ctx->get_sampled_probs_count(i)); +} + // llama adapter API int32_t llama_set_adapter_lora( diff --git a/src/llama-context.h b/src/llama-context.h index c31101330e..b29edf4db2 100644 --- a/src/llama-context.h +++ b/src/llama-context.h @@ -70,6 +70,18 @@ struct llama_context { float * get_embeddings_ith(int32_t i); float * get_embeddings_seq(llama_seq_id seq_id); + llama_token * get_sampled_tokens() const; + llama_token get_sampled_token_ith(int32_t idx); + + float * get_sampled_logits_ith(int32_t idx); + size_t get_sampled_logits_count(int32_t idx); + + float * get_sampled_probs_ith(int32_t idx); + size_t get_sampled_probs_count(int32_t idx); + + const llama_token * get_sampled_candidates_ith(int32_t idx); + size_t get_sampled_candidates_count(int32_t idx); + void attach_threadpool( ggml_threadpool_t threadpool, ggml_threadpool_t threadpool_batch); @@ -192,10 +204,13 @@ private: // Make sure enough space is available for outputs. // Returns max number of outputs for which space was reserved. - uint32_t output_reserve(int32_t n_outputs); + uint32_t output_reserve(int32_t n_outputs, const llama_batch & batch); void output_reorder(); + // map the output row index `i` to batch index + int64_t output_resolve_row(int32_t i) const; + // // graph // @@ -213,6 +228,8 @@ public: ggml_cgraph * graph_reserve( uint32_t n_tokens, uint32_t n_seqs, uint32_t n_outputs, const llama_memory_context_i * mctx, bool split_only = false, size_t * sizes = nullptr); + bool set_sampler(llama_seq_id seq_id, llama_sampler * sampler); + private: llm_graph_params graph_params( llm_graph_result * res, @@ -252,6 +269,31 @@ private: size_t embd_size = 0; // capacity (of floats) for embeddings float * embd = nullptr; + // TODO: simplify + struct sampling_info { + std::map samplers; + + float * logits = nullptr; + size_t logits_size = 0; + + llama_token * sampled = nullptr; + size_t sampled_size = 0; + + float * probs = nullptr; + size_t probs_size = 0; + + llama_token * candidates = nullptr; + size_t candidates_size = 0; + + std::vector logits_count; + std::vector probs_count; + std::vector candidates_count; + + std::vector token_ids_full_vocab; + }; + + sampling_info sampling; + // sequence embeddings output (map of [n_embd] vectors) // populated only when pooling_type != LLAMA_POOLING_TYPE_NONE std::map> embd_seq; diff --git a/src/llama-grammar.cpp b/src/llama-grammar.cpp index 75d5d750c3..64ea2fd00a 100644 --- a/src/llama-grammar.cpp +++ b/src/llama-grammar.cpp @@ -369,6 +369,44 @@ static void print_rule( fprintf(file, "\n"); } +// +// Regex utilities +// + +size_t llama_grammar_trigger_pattern::find(const std::string & input) const { + auto find_start_pos = [](const std::smatch & match) { + // get from the first matched capturing group to the end of the string + size_t start = std::string::npos; + for (auto i = 1u; i < match.size(); i++) { + if (match.length(i) > 0) { + start = match.position(i); + break; + } + } + if (start == std::string::npos) { + start = match.position(0); + } + return start; + }; + + if (!pattern.empty() && pattern.front() == '^' && pattern.back() == '$') { + // match against the entire input + std::smatch match; + if (std::regex_match(input, match, regex)) { + return find_start_pos(match); + } + } + + // search anywhere + std::smatch match; + if (std::regex_search(input, match, regex)) { + return find_start_pos(match); + } + + return std::string::npos; +} + + // // implementation // @@ -1312,21 +1350,10 @@ void llama_grammar_accept_impl(struct llama_grammar & grammar, llama_token token grammar.trigger_buffer_positions.push_back(std::make_pair(token, position)); grammar.trigger_buffer += piece; - std::smatch match; for (const auto & trigger_pattern : grammar.trigger_patterns) { - if (std::regex_match(grammar.trigger_buffer, match, trigger_pattern.regex)) { + auto start = trigger_pattern.find(grammar.trigger_buffer); + if (start != std::string::npos) { grammar.awaiting_trigger = false; - // get from the first matched capturing group to the end of the string - size_t start = std::string::npos; - for (auto i = 1u; i < match.size(); i++) { - if (match.length(i) > 0) { - start = match.position(i); - break; - } - } - if (start == std::string::npos) { - start = match.position(0); - } // replay tokens that overlap with [start, end) for (const auto & [tok, tok_pos] : grammar.trigger_buffer_positions) { diff --git a/src/llama-grammar.h b/src/llama-grammar.h index a4c978ac11..b5a0e588e9 100644 --- a/src/llama-grammar.h +++ b/src/llama-grammar.h @@ -119,6 +119,8 @@ struct llama_grammar_parser { struct llama_grammar_trigger_pattern { std::string pattern; std::regex regex; + + size_t find(const std::string & input) const; }; struct llama_grammar { diff --git a/src/llama-graph.cpp b/src/llama-graph.cpp index 1d0d7197e1..374ff1ebf3 100644 --- a/src/llama-graph.cpp +++ b/src/llama-graph.cpp @@ -12,6 +12,7 @@ #include #include #include +#include void llm_graph_input_embd::set_input(const llama_ubatch * ubatch) { if (ubatch->token) { @@ -32,7 +33,7 @@ bool llm_graph_input_embd::can_reuse(const llm_graph_params & params) { bool res = true; res &= (!tokens && !params.ubatch.token) || (tokens && tokens->ne[0] == params.ubatch.n_tokens); - res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[0] == params.ubatch.n_tokens); + res &= (!embd && !params.ubatch.embd) || (embd && embd->ne[1] == params.ubatch.n_tokens); return res; } @@ -62,7 +63,7 @@ void llm_graph_input_pos::set_input(const llama_ubatch * ubatch) { bool llm_graph_input_pos::can_reuse(const llm_graph_params & params) { bool res = true; - res &= pos->ne[0] == params.ubatch.n_tokens; + res &= pos->ne[0] == params.ubatch.n_tokens*n_pos_per_embd; return res; } @@ -521,6 +522,43 @@ bool llm_graph_input_mem_hybrid::can_reuse(const llm_graph_params & params) { return res; } +void llm_graph_input_sampling::set_input(const llama_ubatch * ubatch) { + // set the inputs only for the active samplers in the current ubatch + std::unordered_set active_samplers; + for (uint32_t i = 0; i < ubatch->n_tokens; i++) { + if (ubatch->output[i]) { + llama_seq_id seq_id = ubatch->seq_id[i][0]; + active_samplers.insert(seq_id); + } + } + + for (auto seq_id : active_samplers) { + if (samplers.find(seq_id) == samplers.end()) { + continue; + } + + auto & sampler = samplers[seq_id]; + + if (sampler->iface->backend_set_input) { + sampler->iface->backend_set_input(sampler); + } + } +} + +bool llm_graph_input_sampling::can_reuse(const llm_graph_params & params) { + if (samplers.size() != params.samplers.size()) { + return false; + } + + for (const auto & [seq_id, sampler] : params.samplers) { + if (samplers[seq_id] != sampler) { + return false; + } + } + + return true; +} + // // llm_graph_result // @@ -541,6 +579,10 @@ void llm_graph_result::reset() { t_logits = nullptr; t_embd = nullptr; t_embd_pooled = nullptr; + t_sampled.clear(); + t_sampled_probs.clear(); + t_sampled_logits.clear(); + t_candidates.clear(); params = {}; @@ -565,6 +607,38 @@ void llm_graph_result::set_inputs(const llama_ubatch * ubatch) { } } +void llm_graph_result::set_outputs() { + if (t_logits != nullptr) { + ggml_set_output(t_logits); + } + if (t_embd != nullptr) { + ggml_set_output(t_embd); + } + if (t_embd_pooled != nullptr) { + ggml_set_output(t_embd_pooled); + } + for (auto & [seq_id, t] : t_sampled) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_probs) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_sampled_logits) { + if (t != nullptr) { + ggml_set_output(t); + } + } + for (auto & [seq_id, t] : t_candidates) { + if (t != nullptr) { + ggml_set_output(t); + } + } +} + bool llm_graph_result::can_reuse(const llm_graph_params & params) { if (!this->params.allow_reuse(params)) { if (debug > 1) { @@ -646,6 +720,7 @@ llm_graph_context::llm_graph_context(const llm_graph_params & params) : loras (params.loras), mctx (params.mctx), cross (params.cross), + samplers (params.samplers), cb_func (params.cb), res (params.res), ctx0 (res->get_ctx()), @@ -1251,6 +1326,10 @@ ggml_tensor * llm_graph_context::build_inp_embd(ggml_tensor * tok_embd) const { res->add_input(std::move(inp)); + // make sure the produced embeddings are immediately materialized in the ggml graph + // ref: https://github.com/ggml-org/llama.cpp/pull/18599 + ggml_build_forward_expand(gf, cur); + return cur; } @@ -1834,8 +1913,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask); + ggml_set_name(inp->self_kq_mask, "self_kq_mask"); inp->self_kq_mask_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask, GGML_TYPE_F16) : inp->self_kq_mask; + ggml_set_name(inp->self_kq_mask_cnv, "self_kq_mask_cnv"); } { @@ -1848,8 +1929,10 @@ llm_graph_input_attn_kv_iswa * llm_graph_context::build_attn_inp_kv_iswa() const inp->self_kq_mask_swa = ggml_new_tensor_4d(ctx0, GGML_TYPE_F32, n_kv, n_tokens/n_stream, 1, n_stream); ggml_set_input(inp->self_kq_mask_swa); + ggml_set_name(inp->self_kq_mask_swa, "self_kq_mask_swa"); inp->self_kq_mask_swa_cnv = cparams.flash_attn ? ggml_cast(ctx0, inp->self_kq_mask_swa, GGML_TYPE_F16) : inp->self_kq_mask_swa; + ggml_set_name(inp->self_kq_mask_swa_cnv, "self_kq_mask_swa_cnv"); } return (llm_graph_input_attn_kv_iswa *) res->add_input(std::move(inp)); @@ -1988,14 +2071,18 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { void llm_graph_context::build_dense_out( ggml_tensor * dense_2, ggml_tensor * dense_3) const { - if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + if (!cparams.embeddings || !(dense_2 || dense_3)) { return; } ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); - cur = ggml_mul_mat(ctx0, dense_2, cur); - cur = ggml_mul_mat(ctx0, dense_3, cur); + if (dense_2) { + cur = ggml_mul_mat(ctx0, dense_2, cur); + } + if (dense_3) { + cur = ggml_mul_mat(ctx0, dense_3, cur); + } cb(cur, "result_embd_pooled", -1); res->t_embd_pooled = cur; ggml_build_forward_expand(gf, cur); @@ -2086,6 +2173,87 @@ void llm_graph_context::build_pooling( ggml_build_forward_expand(gf, cur); } +void llm_graph_context::build_sampling() const { + if (samplers.empty() || !res->t_logits) { + return; + } + + auto inp_sampling = std::make_unique(samplers); + res->add_input(std::move(inp_sampling)); + + std::map seq_to_logit_row; + int32_t logit_row_idx = 0; + + for (uint32_t i = 0; i < ubatch.n_tokens; i++) { + if (ubatch.output[i]) { + llama_seq_id seq_id = ubatch.seq_id[i][0]; + seq_to_logit_row[seq_id] = logit_row_idx; + logit_row_idx++; + } + } + + // res->t_logits will contain logits for all tokens that want the logits calculated (logits=1 or output=1) + GGML_ASSERT(res->t_logits != nullptr && "missing t_logits tensor"); + + // add a dummy row of logits + // this trick makes the graph static, regardless of which samplers are activated + // this is important in order to minimize graph reallocations + // TODO: use `ggml_build_forward_select()` when available (https://github.com/ggml-org/llama.cpp/pull/18550) + ggml_tensor * logits_t = ggml_pad(ctx0, res->t_logits, 0, 1, 0, 0); + + for (const auto & [seq_id, sampler] : samplers) { + const auto it = seq_to_logit_row.find(seq_id); + + // inactive samplers always work on the first row + const auto row_idx = seq_to_logit_row.find(seq_id) != seq_to_logit_row.end() ? it->second : 0; + + ggml_tensor * logits_seq = ggml_view_1d(ctx0, logits_t, logits_t->ne[0], row_idx * logits_t->nb[1]); + ggml_format_name(logits_seq, "logits_seq_%d", seq_id); + + struct llama_sampler_data data = { + /*.logits =*/ logits_seq, + /*.probs =*/ nullptr, + /*.sampled =*/ nullptr, + /*.candidates =*/ nullptr, + }; + + assert(sampler->iface->backend_apply); + sampler->iface->backend_apply(sampler, ctx0, gf, &data); + + if (data.sampled != nullptr) { + res->t_sampled[seq_id] = data.sampled; + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.probs != nullptr) { + res->t_sampled_probs[seq_id] = data.probs; + ggml_build_forward_expand(gf, data.probs); + } + + if (data.logits != nullptr) { + res->t_sampled_logits[seq_id] = data.logits; + ggml_build_forward_expand(gf, data.logits); + } + + if (data.candidates != nullptr) { + res->t_candidates[seq_id] = data.candidates; + ggml_build_forward_expand(gf, data.candidates); + } + } + + // TODO: Call llama_sampler_accept_ggml after all samplers have been applied. + /* + for (const auto & [seq_id, sampler] : samplers) { + if (auto it = res->t_sampled.find(seq_id); it != res->t_sampled.end()) { + ggml_tensor * selected_token = it->second; + if (selected_token != nullptr) { + llama_sampler_accept_ggml(sampler, ctx0, gf, selected_token); + } + } + } + */ +} + int32_t llama_relative_position_bucket(llama_pos x, llama_pos y, uint64_t n_buckets, bool bidirectional) { // TODO move to hparams if a T5 variant appears that uses a different value const int64_t max_distance = 128; diff --git a/src/llama-graph.h b/src/llama-graph.h index 81ac329cc3..503ffd695a 100644 --- a/src/llama-graph.h +++ b/src/llama-graph.h @@ -10,6 +10,7 @@ #include #include #include +#include struct ggml_cgraph; struct ggml_context; @@ -396,6 +397,18 @@ public: const llama_memory_hybrid_context * mctx; }; +class llm_graph_input_sampling : public llm_graph_input_i { +public: + llm_graph_input_sampling(std::map samplers) : + samplers(std::move(samplers)) { } + virtual ~llm_graph_input_sampling() = default; + + void set_input(const llama_ubatch * ubatch) override; + bool can_reuse(const llm_graph_params & params) override; + + std::map samplers; +}; + // // llm_graph_result // @@ -429,6 +442,23 @@ struct llm_graph_params { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + + static bool samplers_equal( + const std::map & lhs, + const std::map & rhs) { + if (lhs.size() != rhs.size()) { + return false; + } + for (const auto & [seq_id, sampler] : lhs) { + auto it = rhs.find(seq_id); + if (it == rhs.end() || it->second != sampler) { + return false; + } + } + return true; + } + uint32_t n_outputs; llm_graph_cb cb; @@ -468,15 +498,36 @@ struct llm_graph_params { return false; } + if (n_outputs != other.n_outputs) { + return false; + } + + if (!samplers_equal(samplers, other.samplers)) { + return false; + } + + if (samplers.size() > 0) { + if (!ubatch.data || !other.ubatch.data) { + return false; + } + + // check that the outputs are the same for all samplers + for (uint32_t i = 0; i < ubatch.n_tokens; ++i) { + if (ubatch.output[i] != other.ubatch.output[i] || + ubatch.seq_id[i][0] != other.ubatch.seq_id[i][0]) { + return false; + } + } + } + return cparams.embeddings == other.cparams.embeddings && cparams.causal_attn == other.cparams.causal_attn && - arch == other.arch && - gtype == other.gtype && - cvec == other.cvec && - loras == other.loras && - cross == other.cross && - n_outputs == other.n_outputs; + arch == other.arch && + gtype == other.gtype && + cvec == other.cvec && + loras == other.loras && + cross == other.cross; } }; @@ -499,6 +550,7 @@ public: void reset(); void set_inputs(const llama_ubatch * ubatch); + void set_outputs(); // try to update the existing graph result using the new graph parameters in order to reuse it // this can only be done if we determine that the resulting graph using the new graph parameters @@ -517,6 +569,11 @@ public: ggml_tensor * t_embd = nullptr; ggml_tensor * t_embd_pooled = nullptr; + std::map t_sampled_logits; + std::map t_candidates; + std::map t_sampled; + std::map t_sampled_probs; + std::vector inputs; ggml_context_ptr ctx_compute; @@ -592,6 +649,8 @@ struct llm_graph_context { const llama_memory_context_i * mctx; const llama_cross * cross; + std::map samplers; + const llm_graph_cb & cb_func; llm_graph_result * res; @@ -832,6 +891,12 @@ struct llm_graph_context { ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + // + // sampling (backend sampling) + // + + void build_sampling() const; + // // dense (out) // diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp index fe1fa4341d..c847ef91b7 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp @@ -72,6 +72,10 @@ uint32_t llama_hparams::n_embd_inp() const { return n_embd_inp; } +uint32_t llama_hparams::get_n_embd_out() const { + return n_embd_out > 0 ? n_embd_out : n_embd; +} + uint32_t llama_hparams::n_embd_k_gqa(uint32_t il) const { const uint32_t n_head_kv = this->n_head_kv(il); diff --git a/src/llama-hparams.h b/src/llama-hparams.h index 42def73f06..7ae3ec292e 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -105,9 +105,9 @@ struct llama_hparams { float rope_attn_factor = 1.0f; float rope_freq_base_train; - float rope_freq_base_train_swa; + float rope_freq_base_train_swa = 10000.0f; float rope_freq_scale_train; - float rope_freq_scale_train_swa; + float rope_freq_scale_train_swa = 1.0f; uint32_t n_ctx_orig_yarn; float rope_yarn_log_mul = 0.0f; @@ -162,6 +162,9 @@ struct llama_hparams { // for Classifiers uint32_t n_cls_out = 1; + // output embedding dimension (0 = use n_embd) + uint32_t n_embd_out = 0; + // llama4 smallthinker uint32_t n_moe_layer_step = 0; uint32_t n_no_rope_layer_step = 4; @@ -234,6 +237,9 @@ struct llama_hparams { // dimension of main + auxiliary input embeddings uint32_t n_embd_inp() const; + // dimension of output embeddings + uint32_t get_n_embd_out() const; + // dimension of key embeddings across all k-v heads uint32_t n_embd_k_gqa(uint32_t il = 0) const; diff --git a/src/llama-model-saver.cpp b/src/llama-model-saver.cpp index 563823dc35..ae27c71ce2 100644 --- a/src/llama-model-saver.cpp +++ b/src/llama-model-saver.cpp @@ -146,6 +146,9 @@ void llama_model_saver::add_kv_from_model() { add_kv(LLM_KV_VOCAB_SIZE, vocab.n_tokens()); add_kv(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); add_kv(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + if (hparams.n_embd_out > 0) { + add_kv(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out); + } add_kv(LLM_KV_BLOCK_COUNT, hparams.n_layer); add_kv(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); add_kv(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, true); diff --git a/src/llama-model.cpp b/src/llama-model.cpp index 5e664c8c57..04c48b5fd3 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp @@ -126,6 +126,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_31B_A3_5B: return "31B.A3.5B"; case LLM_TYPE_80B_A3B: return "80B.A3B"; case LLM_TYPE_100B_A6B: return "100B.A6B"; + case LLM_TYPE_102B_A12B: return "102B.A12B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; case LLM_TYPE_230B_A10B: return "230B.A10B"; case LLM_TYPE_235B_A22B: return "235B.A22B"; @@ -506,6 +507,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_CONTEXT_LENGTH, hparams.n_ctx_train); ml.get_key(LLM_KV_EMBEDDING_LENGTH, hparams.n_embd); + ml.get_key(LLM_KV_EMBEDDING_LENGTH_OUT, hparams.n_embd_out, false); ml.get_key(LLM_KV_BLOCK_COUNT, hparams.n_layer); ml.get_key(LLM_KV_EXPERT_COUNT, hparams.n_expert, false); ml.get_key(LLM_KV_EXPERT_USED_COUNT, hparams.n_expert_used, false); @@ -577,6 +579,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.rope_scaling_type_train = llama_rope_scaling_type_from_string(rope_scaling); GGML_ASSERT(hparams.rope_scaling_type_train != LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED); + // TODO: Handle SWA metadata similarly when models start implementing it // rope_freq_scale (inverse of the kv) is optional float ropescale = 0.0f; if (!ml.get_key(LLM_KV_ROPE_SCALING_FACTOR, ropescale, false)) { @@ -585,10 +588,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { } hparams.rope_freq_scale_train = ropescale == 0.0f ? 1.0f : 1.0f/ropescale; - // by default assume that the sliding-window layers use the same scaling type as the non-sliding-window layers - hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; - hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; - ml.get_key(LLM_KV_ROPE_SCALING_ATTN_FACTOR, hparams.rope_attn_factor, false); // non-transformer models do not have attention heads @@ -676,6 +675,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.f_attn_temp_scale = 0.1f; hparams.f_attn_temp_offset = 1.0f; hparams.set_swa_pattern(4); // pattern: 3 chunked - 1 full + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } switch (hparams.n_expert) { @@ -721,6 +724,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1109,6 +1116,14 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_MAINCODER: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_1B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; case LLM_ARCH_QWEN3VL: { ml.get_key(LLM_KV_NUM_DEEPSTACK_LAYERS, hparams.n_deepstack_layers, false); @@ -1234,7 +1249,6 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (found_swa && hparams.n_swa > 0) { uint32_t swa_period = 8; hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; - hparams.rope_freq_scale_train_swa = 1.0f; ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa); ml.get_key_or_arr(LLM_KV_ATTENTION_SLIDING_WINDOW_PATTERN, swa_period, false); hparams.set_swa_pattern(swa_period); @@ -1300,7 +1314,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.n_swa = 4096; // default value of gemma 2 hparams.set_swa_pattern(2); hparams.attn_soft_cap = true; + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_ATTN_LOGIT_SOFTCAPPING, hparams.f_attn_logit_softcapping, false); @@ -1325,8 +1342,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(6); - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1356,10 +1372,9 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(5); hparams.n_layer_kv_from_start = 20; - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; hparams.f_attention_scale = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); @@ -1375,9 +1390,8 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; - hparams.rope_freq_scale_train_swa = 1.0f; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); @@ -1516,7 +1530,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_LOGIT_SCALE, hparams.f_logit_scale); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); @@ -1555,6 +1572,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { if (found_swa && hparams.n_swa > 0) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = 1.0; // See olmo2.cpp + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; } @@ -1682,7 +1703,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH_MLA, hparams.n_embd_head_v_mla, false); ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); ml.get_key(LLM_KV_EXPERT_SHARED_COUNT, hparams.n_expert_shared); - ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale); + ml.get_key(LLM_KV_EXPERT_WEIGHTS_SCALE, hparams.expert_weights_scale, false); ml.get_key(LLM_KV_EXPERT_WEIGHTS_NORM, hparams.expert_weights_norm, false); ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func, false); if (hparams.expert_gating_func == LLAMA_EXPERT_GATING_FUNC_TYPE_NONE) { @@ -1778,6 +1799,7 @@ void llama_model::load_hparams(llama_model_loader & ml) { switch (hparams.n_layer) { case 47: type = LLM_TYPE_106B_A12B; break; // GLM-4.5-Air (46 layers + 1 NextN layer) + case 48: type = LLM_TYPE_102B_A12B; break; // Solar Open case 93: type = LLM_TYPE_355B_A32B; break; // GLM-4.5 (92 layers + 1 NextN layer) default: type = LLM_TYPE_UNKNOWN; } @@ -1896,6 +1918,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2198,6 +2224,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.set_swa_pattern(2); + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); + switch (hparams.n_layer) { case 24: type = LLM_TYPE_20B; break; case 36: type = LLM_TYPE_120B; break; @@ -2242,6 +2272,10 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.swa_type = LLAMA_SWA_TYPE_STANDARD; hparams.n_swa = 4096; hparams.set_swa_pattern(4, true); + + hparams.rope_freq_base_train_swa = hparams.rope_freq_base_train; + hparams.rope_freq_scale_train_swa = hparams.rope_freq_scale_train; + ml.get_key(LLM_KV_ROPE_FREQ_BASE_SWA, hparams.rope_freq_base_train_swa, false); } else { hparams.swa_type = LLAMA_SWA_TYPE_NONE; hparams.n_no_rope_layer_step = hparams.n_layer; @@ -3320,7 +3354,14 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.attn_norm_2_b = create_tensor(tn(LLM_TENSOR_ATTN_NORM_2, "bias", i), {n_embd}, TENSOR_NOT_REQUIRED); layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, TENSOR_NOT_REQUIRED); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, layer.ffn_gate ? n_ff : n_ff * 2}, 0); + + const auto tn_ffn_up_weight = tn(LLM_TENSOR_FFN_UP, "weight", i); + ggml_tensor * t_ffn_up = ml.get_tensor_meta(tn_ffn_up_weight.str().c_str()); + const int64_t n_ffn_up = t_ffn_up ? t_ffn_up->ne[1] : n_ff; + + GGML_ASSERT(n_ffn_up == n_ff || n_ffn_up == n_ff * 2); + layer.ffn_up = create_tensor(tn_ffn_up_weight, {n_embd, n_ffn_up}, 0); + layer.ffn_up_b = create_tensor(tn(LLM_TENSOR_FFN_UP, "bias", i), {n_ffn_up}, TENSOR_NOT_REQUIRED); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), {n_ff, n_embd}, 0); layer.ffn_down_b = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "bias", i), {n_embd}, 0); @@ -4776,7 +4817,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4839,7 +4884,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // output output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); - output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, 0); + // try to load output.weight, if not found, use token_embd (tied embeddings) + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + if (!output) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -5206,9 +5255,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, flags); layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, flags); layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, flags); - layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, flags); - layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, flags); - layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, flags); + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd_head_k * n_head }, TENSOR_NOT_REQUIRED | flags); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_k_gqa }, TENSOR_NOT_REQUIRED | flags); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_v_gqa }, TENSOR_NOT_REQUIRED | flags); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, flags); @@ -6421,6 +6470,9 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.shortconv.out_proj = create_tensor(tn(LLM_TENSOR_SHORTCONV_OUTPROJ, "weight", i), {n_embd, n_embd}, 0); } } + + // for LFM2-ColBert-350M + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.get_n_embd_out()}, TENSOR_NOT_REQUIRED); } break; case LLM_ARCH_SMALLTHINKER: { @@ -6761,6 +6813,37 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, TENSOR_NOT_REQUIRED); } } break; + case LLM_ARCH_MAINCODER: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED); + // if output is NULL, init from the input tok embed + if (output == NULL) { + output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); + } + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), {n_embd, n_embd_head_k * n_head}, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), {n_embd, n_embd_gqa}, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {n_embd_head_k * n_head, n_embd}, 0); + + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {n_embd_head_k}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {n_embd_head_k}, 0); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -7042,6 +7125,10 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: rope scaling = %s\n", __func__, rope_scaling_type.c_str()); LLAMA_LOG_INFO("%s: freq_base_train = %.1f\n", __func__, hparams.rope_freq_base_train); LLAMA_LOG_INFO("%s: freq_scale_train = %g\n", __func__, hparams.rope_freq_scale_train); + if (hparams.swa_type != LLAMA_SWA_TYPE_NONE) { + LLAMA_LOG_INFO("%s: freq_base_swa = %.1f\n", __func__, hparams.rope_freq_base_train_swa); + LLAMA_LOG_INFO("%s: freq_scale_swa = %g\n", __func__, hparams.rope_freq_scale_train_swa); + } LLAMA_LOG_INFO("%s: n_ctx_orig_yarn = %u\n", __func__, hparams.n_ctx_orig_yarn); LLAMA_LOG_INFO("%s: rope_yarn_log_mul= %.4f\n", __func__, hparams.rope_yarn_log_mul); LLAMA_LOG_INFO("%s: rope_finetuned = %s\n", __func__, hparams.rope_finetuned ? "yes" : "unknown"); @@ -7406,6 +7493,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique>(*this, params); } break; + case LLM_ARCH_MAINCODER: + { + llm = std::make_unique(*this, params); + } break; case LLM_ARCH_DECI: { llm = std::make_unique(*this, params); @@ -7440,7 +7531,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { } break; case LLM_ARCH_MODERN_BERT: { - llm = std::make_unique>(*this, params); + llm = std::make_unique(*this, params); } break; case LLM_ARCH_NEO_BERT: { @@ -7850,12 +7941,17 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // add backend sampling layers (if any) + llm->build_sampling(); + // if the gguf model was converted with --sentence-transformers-dense-modules // there will be two additional dense projection layers // dense linear projections are applied after pooling // TODO: move reranking logic here and generalize llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + llm->res->set_outputs(); + return llm->res->get_gf(); } @@ -7911,6 +8007,10 @@ int32_t llama_model_n_embd_inp(const llama_model * model) { return model->hparams.n_embd_inp(); } +int32_t llama_model_n_embd_out(const llama_model * model) { + return model->hparams.get_n_embd_out(); +} + int32_t llama_model_n_layer(const llama_model * model) { return model->hparams.n_layer; } @@ -8014,6 +8114,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_ERNIE4_5_MOE: case LLM_ARCH_MISTRAL3: case LLM_ARCH_LLAMA_EMBED: + case LLM_ARCH_MAINCODER: return LLAMA_ROPE_TYPE_NORM; // the pairs of head values are offset by n_rot/2 diff --git a/src/llama-model.h b/src/llama-model.h index f4f44a92b6..79200a0d97 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -119,6 +119,7 @@ enum llm_type { LLM_TYPE_31B_A3_5B, LLM_TYPE_80B_A3B, // Qwen3 Next LLM_TYPE_100B_A6B, + LLM_TYPE_102B_A12B, // Solar-Open LLM_TYPE_106B_A12B, // GLM-4.5-Air LLM_TYPE_230B_A10B, // Minimax M2 LLM_TYPE_235B_A22B, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp index f3891453e4..48291a3a7c 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -4,6 +4,8 @@ #include "llama-vocab.h" #include "llama-grammar.h" +#include "ggml-cpp.h" + #include #include #include @@ -346,7 +348,9 @@ static uint32_t get_rng_seed(uint32_t seed) { // llama_sampler API -struct llama_sampler * llama_sampler_init(const struct llama_sampler_i * iface, llama_sampler_context_t ctx) { +struct llama_sampler * llama_sampler_init( + struct llama_sampler_i * iface, + llama_sampler_context_t ctx) { return new llama_sampler { /* .iface = */ iface, /* .ctx = */ ctx, @@ -421,6 +425,202 @@ void llama_sampler_free(struct llama_sampler * smpl) { delete smpl; } +// empty sampler + +struct llama_sampler_empty { + const char * name; +}; + +static struct llama_sampler * llama_sampler_init_empty(const char * name); + +static const char * llama_sampler_empty_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return ctx->name; +} + +static void llama_sampler_empty_accept(struct llama_sampler * smpl, llama_token token) { + GGML_UNUSED(smpl); + GGML_UNUSED(token); +} + +static void llama_sampler_empty_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { + GGML_UNUSED(smpl); + GGML_UNUSED(cur_p); +} + +static void llama_sampler_empty_reset(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler * llama_sampler_empty_clone(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_empty *) smpl->ctx; + return llama_sampler_init_empty(ctx->name); +} + +static void llama_sampler_empty_free(struct llama_sampler * smpl) { + delete (llama_sampler_empty *) smpl->ctx; +} + +static bool llama_sampler_empty_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + GGML_UNUSED(smpl); + GGML_UNUSED(buft); + + return true; +} + +static void llama_sampler_empty_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(selected_token); +} + +static void llama_sampler_empty_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(smpl); + GGML_UNUSED(ctx); + GGML_UNUSED(gf); + GGML_UNUSED(data); +} + +static void llama_sampler_empty_backend_set_input(struct llama_sampler * smpl) { + GGML_UNUSED(smpl); +} + +static struct llama_sampler_i llama_sampler_empty_i = { + /* .name = */ llama_sampler_empty_name, + /* .accept = */ llama_sampler_empty_accept, + /* .apply = */ llama_sampler_empty_apply, + /* .reset = */ llama_sampler_empty_reset, + /* .clone = */ llama_sampler_empty_clone, + /* .free = */ llama_sampler_empty_free, + /* .backend_init = */ llama_sampler_empty_backend_init, + /* .backend_accept = */ llama_sampler_empty_backend_accept, + /* .backend_apply = */ llama_sampler_empty_backend_apply, + /* .backend_set_input = */ llama_sampler_empty_backend_set_input, +}; + +struct llama_sampler * llama_sampler_init_empty(const char * name) { + return llama_sampler_init( + /* .iface = */ &llama_sampler_empty_i, + /* .ctx = */ new llama_sampler_empty { + /* .name = */ name, + } + ); +} + +// common backend sampler functionality +// +// +name : means that the sampler is support and will run on the backend +// -name : means that a ggml operator is not supported by the backend +// +struct llama_sampler_backend { + llama_sampler_backend(const char * name) : name(name), name_ext(name), is_init(false), support(false) {} + + const char * get_name() { + if (!is_init) { + return name.c_str(); + } + + if (support) { + name_ext = "+" + name; + } else { + name_ext = "-" + name; + } + + return name_ext.c_str(); + } + + void init(bool support) { + GGML_ASSERT(this->is_init == false); + + this->is_init = true; + this->support = support; + } + +private: + std::string name; + std::string name_ext; + + bool is_init; + bool support; +}; + +// check if all ggml ops used by the sampler are supported by the backend +static bool llama_sampler_backend_support( + llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * device = ggml_backend_buft_get_device(buft); + if (!device) { + // CPU backend always supported + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 128*ggml_tensor_overhead() + ggml_graph_overhead(), + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context_ptr ctx_ptr { ggml_init(params) }; + if (!ctx_ptr) { + throw std::runtime_error(format("failed to create ggml context")); + } + + ggml_context * ctx = ctx_ptr.get(); + + const int64_t n = 1024*1024; + + llama_sampler_data data = { + /*.logits = */ ggml_new_tensor_1d(ctx, GGML_TYPE_F32, n), + /*.probs = */ nullptr, + /*.sampled = */ nullptr, + /*.candidates = */ ggml_new_tensor_1d(ctx, GGML_TYPE_I32, n), + }; + + ggml_cgraph * gf = ggml_new_graph(ctx); + + smpl->iface->backend_apply(smpl, ctx, gf, &data); + + if (data.logits) { + ggml_build_forward_expand(gf, data.logits); + } + + if (data.probs) { + ggml_build_forward_expand(gf, data.probs); + } + + if (data.sampled) { + ggml_build_forward_expand(gf, data.sampled); + } + + if (data.candidates) { + ggml_build_forward_expand(gf, data.candidates); + } + + for (int i = 0; i < ggml_graph_n_nodes(gf); i++) { + struct ggml_tensor * op = ggml_graph_node(gf, i); + + if (!ggml_backend_dev_supports_op(device, op)) { + LLAMA_LOG_WARN("%s: device '%s' does not have support for op %s needed for sampler '%s'\n", + __func__, ggml_backend_dev_name(device), ggml_op_name(op->op), smpl->iface->name(smpl)); + + return false; + } + } + + return true; +} + // sampler chain static const char * llama_sampler_chain_name(const struct llama_sampler * /*smpl*/) { @@ -432,8 +632,8 @@ static void llama_sampler_chain_accept(struct llama_sampler * smpl, llama_token time_meas tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_accept(smpl, token); + for (auto & smpl : chain->samplers) { + llama_sampler_accept(smpl.ptr, token); } chain->n_sample++; @@ -444,16 +644,28 @@ static void llama_sampler_chain_apply(struct llama_sampler * smpl, llama_token_d time_meas tm(chain->t_sample_us, chain->params.no_perf); - for (auto * smpl : chain->samplers) { - llama_sampler_apply(smpl, cur_p); + bool is_backend = chain->is_init; + + for (auto & smpl : chain->samplers) { + if (is_backend && smpl.is_backend) { + continue; + } + + is_backend = false; + + if (smpl.ptr->iface->apply == nullptr) { + continue; + } + + llama_sampler_apply(smpl.ptr, cur_p); } } static void llama_sampler_chain_reset(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_reset(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_reset(smpl.ptr); } } @@ -462,8 +674,8 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl auto * result = llama_sampler_chain_init(chain_src->params); - for (auto * smpl : chain_src->samplers) { - llama_sampler_chain_add(result, llama_sampler_clone(smpl)); + for (const auto & smpl : chain_src->samplers) { + llama_sampler_chain_add(result, llama_sampler_clone(smpl.ptr)); } return result; @@ -472,20 +684,109 @@ static struct llama_sampler * llama_sampler_chain_clone(const struct llama_sampl static void llama_sampler_chain_free(struct llama_sampler * smpl) { auto * chain = (llama_sampler_chain *) smpl->ctx; - for (auto * smpl : chain->samplers) { - llama_sampler_free(smpl); + for (auto & smpl : chain->samplers) { + llama_sampler_free(smpl.ptr); } delete chain; } +static bool llama_sampler_chain_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init == false && "llama_sampler_chain_backend_init() called twice"); + + chain->is_init = true; + + bool res = true; + + for (auto & smpl : chain->samplers) { + bool res_cur = true; + + // to be able to run a sampler on the backend, it has to: + // - have the .backend_init() API implemented + // - return true during .backend_init() + if (smpl.ptr->iface->backend_init) { + if (!smpl.ptr->iface->backend_init(smpl.ptr, buft)) { + res_cur = false; + } + } else { + res_cur = false; + } + + smpl.is_backend = res_cur; + + res = res && res_cur; + } + + return res; +} + +static void llama_sampler_chain_backend_accept( + struct llama_sampler * smpl, + ggml_context * ctx, + ggml_cgraph * gf, + struct ggml_tensor * selected_token) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_accept) { + smpl.ptr->iface->backend_accept(smpl.ptr, ctx, gf, selected_token); + } + } +} + +static void llama_sampler_chain_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + GGML_ASSERT(chain->is_init && "llama_sampler_chain_backend_init() not called"); + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_apply) { + smpl.ptr->iface->backend_apply(smpl.ptr, ctx, gf, data); + } + } +} + +static void llama_sampler_chain_backend_set_input(struct llama_sampler * smpl) { + auto * chain = (llama_sampler_chain *) smpl->ctx; + + for (auto & smpl : chain->samplers) { + if (!smpl.is_backend) { + break; + } + + if (smpl.ptr->iface->backend_set_input) { + smpl.ptr->iface->backend_set_input(smpl.ptr); + } + } +} + static struct llama_sampler_i llama_sampler_chain_i = { - /* .name = */ llama_sampler_chain_name, - /* .accept = */ llama_sampler_chain_accept, - /* .apply = */ llama_sampler_chain_apply, - /* .reset = */ llama_sampler_chain_reset, - /* .clone = */ llama_sampler_chain_clone, - /* .free = */ llama_sampler_chain_free, + /* .name = */ llama_sampler_chain_name, + /* .accept = */ llama_sampler_chain_accept, + /* .apply = */ llama_sampler_chain_apply, + /* .reset = */ llama_sampler_chain_reset, + /* .clone = */ llama_sampler_chain_clone, + /* .free = */ llama_sampler_chain_free, + /* .backend_init = */ llama_sampler_chain_backend_init, + /* .backend_accept = */ llama_sampler_chain_backend_accept, + /* .backend_apply = */ llama_sampler_chain_backend_apply, + /* .backend_set_input = */ llama_sampler_chain_backend_set_input, }; struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_params params) { @@ -493,6 +794,7 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param /* .iface = */ &llama_sampler_chain_i, /* .ctx = */ new llama_sampler_chain { /* .params = */ params, + /* .is_init = */ false, /* .samplers = */ {}, /* .cur = */ {}, /* .t_sample_us = */ 0, @@ -502,7 +804,16 @@ struct llama_sampler * llama_sampler_chain_init(struct llama_sampler_chain_param } llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_context * ctx, int32_t idx) { - const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token sampled_token = llama_get_sampled_token_ith (ctx, idx); + const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx); + const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); + + // If a backend sampler has already sampled a token, return it. + if (sampled_token != LLAMA_TOKEN_NULL) { + LLAMA_LOG_DEBUG("%s: Backend sampler selected token for idx %d. Skipping CPU samplers\n", __func__, idx); + return sampled_token; + } const llama_model * model = llama_get_model(ctx); const llama_vocab * vocab = llama_model_get_vocab(model); @@ -521,9 +832,26 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte } auto & cur = *cur_ptr; - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + + if (sampled_probs) { + const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx); + cur.resize(sampled_probs_count); + for (uint32_t i = 0; i < sampled_probs_count; ++i) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]}; + } + } else if (sampled_logits) { + const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx); + cur.resize(sampled_logits_count); + for (llama_token i = 0; i < (int)sampled_logits_count; i++) { + cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f}; + } + } else { + const auto * logits = llama_get_logits_ith(ctx, idx); + GGML_ASSERT(logits != nullptr); + cur.resize(n_vocab); + for (llama_token token_id = 0; token_id < n_vocab; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } llama_token_data_array cur_p = { @@ -544,19 +872,35 @@ llama_token llama_sampler_sample(struct llama_sampler * smpl, struct llama_conte return token; } + void llama_sampler_chain_add(struct llama_sampler * chain, struct llama_sampler * smpl) { auto * p = (llama_sampler_chain *) chain->ctx; - p->samplers.push_back(smpl); + p->samplers.push_back({ + /* .is_backend = */ false, + /* .ptr = */ smpl, + }); } -struct llama_sampler * llama_sampler_chain_get(const struct llama_sampler * chain, int32_t i) { +struct llama_sampler * llama_sampler_chain_get(struct llama_sampler * chain, int32_t i) { + if (chain == nullptr) { + return nullptr; + } + + if (chain->iface != &llama_sampler_chain_i) { + return nullptr; + } + + if (i == -1) { + return chain; + } + const auto * p = (const llama_sampler_chain *) chain->ctx; if (i < 0 || (size_t) i >= p->samplers.size()) { return nullptr; } - return p->samplers[i]; + return p->samplers[i].ptr; } struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, int32_t i) { @@ -566,7 +910,7 @@ struct llama_sampler * llama_sampler_chain_remove(struct llama_sampler * chain, return nullptr; } - auto * result = p->samplers[i]; + auto * result = p->samplers[i].ptr; p->samplers.erase(p->samplers.begin() + i); return result; @@ -584,8 +928,36 @@ int llama_sampler_chain_n(const struct llama_sampler * chain) { // greedy -static const char * llama_sampler_greedy_name(const struct llama_sampler * /*smpl*/) { - return "greedy"; +struct llama_sampler_greedy : public llama_sampler_backend { +}; + +static const char * llama_sampler_greedy_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + return sctx->get_name(); +} + +static void llama_sampler_greedy_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_greedy *) smpl->ctx; + GGML_UNUSED(ctx); +} + +static struct llama_sampler * llama_sampler_greedy_clone(const struct llama_sampler * smpl) { + const auto * ctx = (const llama_sampler_greedy *) smpl->ctx; + auto * result = llama_sampler_init_greedy(); + + // copy the state + { + auto * result_ctx = (llama_sampler_greedy *) result->ctx; + + GGML_UNUSED(ctx); + GGML_UNUSED(result_ctx); + } + + return result; +} + +static void llama_sampler_greedy_free(struct llama_sampler * smpl) { + delete (llama_sampler_greedy *) smpl->ctx; } static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_token_data_array * cur_p) { @@ -597,33 +969,72 @@ static void llama_sampler_greedy_apply(struct llama_sampler * /*smpl*/, llama_to } } +static bool llama_sampler_greedy_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_greedy *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_greedy_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(smpl); + + struct ggml_tensor * curl = ggml_argmax(ctx, data->logits); + ggml_set_name(curl, "greedy_argmax"); + + data->sampled = curl; +} + static struct llama_sampler_i llama_sampler_greedy_i = { - /* .name = */ llama_sampler_greedy_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_greedy_apply, - /* .reset = */ nullptr, - /* .clone = */ nullptr, - /* .free = */ nullptr, + /* .name = */ llama_sampler_greedy_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_greedy_apply, + /* .reset = */ llama_sampler_greedy_reset, + /* .clone = */ llama_sampler_greedy_clone, + /* .free = */ llama_sampler_greedy_free, + /* .backend_init = */ llama_sampler_greedy_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_greedy_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_greedy() { return llama_sampler_init( /* .iface = */ &llama_sampler_greedy_i, - /* .ctx = */ nullptr + /* .ctx = */ new llama_sampler_greedy { + ("greedy"), + } ); } // dist -struct llama_sampler_dist { +struct llama_sampler_dist : public llama_sampler_backend { const uint32_t seed; uint32_t seed_cur; std::mt19937 rng; + + // backend input + struct ggml_tensor * inp_uniform; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_dist_name(const struct llama_sampler * /*smpl*/) { - return "dist"; +static const char * llama_sampler_dist_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -698,6 +1109,12 @@ static void llama_sampler_dist_apply(struct llama_sampler * smpl, llama_token_da #endif } +static void llama_sampler_dist_reset(struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_dist *) smpl->ctx; + ctx->seed_cur = get_rng_seed(ctx->seed); + ctx->rng.seed(ctx->seed_cur); +} + static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sampler * smpl) { const auto * ctx = (const llama_sampler_dist *) smpl->ctx; auto * result = llama_sampler_init_dist(ctx->seed); @@ -712,23 +1129,127 @@ static struct llama_sampler * llama_sampler_dist_clone(const struct llama_sample return result; } -static void llama_sampler_dist_reset(struct llama_sampler * smpl) { - auto * ctx = (llama_sampler_dist *) smpl->ctx; - ctx->seed_cur = get_rng_seed(ctx->seed); - ctx->rng.seed(ctx->seed_cur); -} - static void llama_sampler_dist_free(struct llama_sampler * smpl) { delete (llama_sampler_dist *) smpl->ctx; } +static bool llama_sampler_dist_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + // allocate inputs + { + ggml_init_params params = { + /*.mem_size =*/ ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + // Create the uniform random scalar input tensor. This will be set by + // llama_sampler_dist_backend_set_input after this graph is built. + sctx->inp_uniform = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1); + ggml_set_name (sctx->inp_uniform, "uniform"); + ggml_set_input(sctx->inp_uniform); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + + ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); + } + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + if (!res) { + sctx->inp_ctx.reset(nullptr); + sctx->inp_buf.reset(nullptr); + } + + return res; +} + +static void llama_sampler_dist_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + auto * sctx = (llama_sampler_dist *) smpl->ctx; + + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "dist_probs"); + + struct ggml_tensor * cumsum = ggml_cumsum(ctx, probs); + ggml_set_name(cumsum, "dist_cumsum"); + + // The uniform tensor has a random value and we subtract this tensor with + // the cumsum tensor (the uniform tensor will be broadcasted by ggml_sub). + // Recall that each entry in cumsum is the cumulative probability up to that + // index so values stay negative while the cumulative total is below the + // random value, and become zero/positive once the threshold is crossed. + struct ggml_tensor * diff = ggml_sub(ctx, cumsum, sctx->inp_uniform); + ggml_set_name(diff, "dist_cumsum"); + + // The ggml_step function produces a tensor where entries are 1 if the + // corresponding entry in diff is > 0, and 0 otherwise. So all values up to + // the index where the cumulative probability exceeds the random value are 0, + // and all entries after that are 1. + struct ggml_tensor * mask = ggml_step(ctx, diff); + ggml_set_name(mask, "dist_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "dist_index_f32"); + + // Use ggml_scale_bias to scale the index value by -1 and then add the size + // of the mask to that value so we get the correct index ((-1 * idxf) + n). + struct ggml_tensor * idx = ggml_cast(ctx, ggml_scale_bias(ctx, idxf, -1.0f, mask->ne[0]), GGML_TYPE_I32); + ggml_set_name(idx, "dist_index_i32"); + + // Map back to original vocab ids if a candidates tensor is available. + struct ggml_tensor * sampled_token = idx; + if (data->candidates != nullptr) { + struct ggml_tensor * candidates = ggml_reshape_2d(ctx, data->candidates, 1, ggml_nelements(data->candidates)); + + sampled_token = ggml_get_rows(ctx, candidates, idx); + ggml_set_name(sampled_token, "dist_sampled_token"); + } + + data->sampled = sampled_token; + data->probs = probs; +} + +static void llama_sampler_dist_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_dist *) smpl->ctx; + GGML_ASSERT(sctx->inp_uniform != nullptr); + + // We sample in double precision and cast to float to match rnd numbers of + // llama_dampler_dist which uses double precision (sampling from + // std::uniform_real_distribution and + // std::uniform_real_distribution with same rng will produce + // different sequences). + std::uniform_real_distribution dist(0.0f, 1.0f); + const float rnd = dist(sctx->rng); + + ggml_backend_tensor_set(sctx->inp_uniform, &rnd, 0, sizeof(float)); +} + static struct llama_sampler_i llama_sampler_dist_i = { - /* .name = */ llama_sampler_dist_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_dist_apply, - /* .reset = */ llama_sampler_dist_reset, - /* .clone = */ llama_sampler_dist_clone, - /* .free = */ llama_sampler_dist_free, + /* .name = */ llama_sampler_dist_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_dist_apply, + /* .reset = */ llama_sampler_dist_reset, + /* .clone = */ llama_sampler_dist_clone, + /* .free = */ llama_sampler_dist_free, + /* .backend_init = */ llama_sampler_dist_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_dist_backend_apply, + /* .backend_set_input = */ llama_sampler_dist_backend_set_input, }; struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { @@ -736,21 +1257,26 @@ struct llama_sampler * llama_sampler_init_dist(uint32_t seed) { return llama_sampler_init( /* .iface = */ &llama_sampler_dist_i, /* .ctx = */ new llama_sampler_dist { - /* .seed = */ seed, - /* .seed_cur = */ seed_cur, - /* .rng = */ std::mt19937(seed_cur), + ("dist"), + /* .seed = */ seed, + /* .seed_cur = */ seed_cur, + /* .rng = */ std::mt19937(seed_cur), + /* .inp_uniform = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } // top-k -struct llama_sampler_top_k { +struct llama_sampler_top_k : public llama_sampler_backend { const int32_t k; }; -static const char * llama_sampler_top_k_name(const struct llama_sampler * /*smpl*/) { - return "top-k"; +static const char * llama_sampler_top_k_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_k_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -767,19 +1293,69 @@ static void llama_sampler_top_k_free(struct llama_sampler * smpl) { delete (llama_sampler_top_k *) smpl->ctx; } +static bool llama_sampler_top_k_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_k_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_k *) smpl->ctx; + + struct ggml_tensor * top_k = ggml_top_k(ctx, data->logits, sctx->k); + ggml_set_name(top_k, "top_k"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, top_k); + data->candidates = ggml_reshape_1d(ctx, data->candidates, sctx->k); + ggml_set_name(data->candidates, "top_k_candidates"); + } else { + data->candidates = top_k; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + struct ggml_tensor * top_k_rows = ggml_get_rows(ctx, logits_rows, top_k); + data->logits = ggml_reshape_1d(ctx, top_k_rows, sctx->k); + ggml_set_name(top_k_rows, "top_k_rows"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_k_i = { - /* .name = */ llama_sampler_top_k_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_k_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_k_clone, - /* .free = */ llama_sampler_top_k_free, + /* .name = */ llama_sampler_top_k_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_k_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_k_clone, + /* .free = */ llama_sampler_top_k_free, + /* .backend_init = */ llama_sampler_top_k_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_k_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_k(int32_t k) { + const bool is_empty = (k <= 0); + + if (is_empty) { + return llama_sampler_init_empty("?top-k"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_k_i, /* .ctx = */ new llama_sampler_top_k { + ("top-k"), /* .k = */ k, } ); @@ -787,15 +1363,16 @@ struct llama_sampler * llama_sampler_init_top_k(int32_t k) { // top-p -struct llama_sampler_top_p { +struct llama_sampler_top_p : public llama_sampler_backend { const float p; const size_t min_keep; std::vector buf_sort; }; -static const char * llama_sampler_top_p_name(const struct llama_sampler * /*smpl*/) { - return "top-p"; +static const char * llama_sampler_top_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_top_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -862,19 +1439,118 @@ static void llama_sampler_top_p_free(struct llama_sampler * smpl) { delete (llama_sampler_top_p *) smpl->ctx; } +static bool llama_sampler_top_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_top_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_top_p *) smpl->ctx; + + auto ggml_sort = [ctx](struct ggml_tensor * a, struct ggml_tensor * b) { + GGML_ASSERT(ggml_nrows(a) == 1); + struct ggml_tensor * a_reshaped = ggml_reshape_2d(ctx, a, 1, a->ne[0]); + struct ggml_tensor * a_sorted = ggml_get_rows(ctx, a_reshaped, b); + return ggml_reshape_1d(ctx, a_sorted, a->ne[0]); + }; + + // Get the sorted logits in descending order. + struct ggml_tensor * sorted_idx = ggml_argsort(ctx, data->logits, GGML_SORT_ORDER_DESC); + ggml_set_name(sorted_idx, "top_p_sorted_idx"); + + // Do the sorting via reshape + get_rows + struct ggml_tensor * sorted_logits = ggml_sort(data->logits, sorted_idx); + ggml_set_name(sorted_logits, "top_p_sorted_logits"); + + struct ggml_tensor * softmax = ggml_soft_max(ctx, sorted_logits); + ggml_set_name(softmax, "top_p_softmax"); + + // If candidates are provided, sort them as well. Otherwise, set sorted indices as candidates. + if (data->candidates) { + data->candidates = ggml_sort(data->candidates, sorted_idx); + } else { + data->candidates = sorted_idx; + } + ggml_set_name(data->candidates, "top_p_candidates"); + + // Compute Cumulative Distribution Function (CDF) by means of GGML_OP_CUMSUM. + struct ggml_tensor * cdf = ggml_cumsum(ctx, softmax); + ggml_set_name(cdf, "top_p_cdf"); + + // Invert CDF and add top-p value so that ggml_step yields 1 for values we want to keep + struct ggml_tensor * cdf_scaled = ggml_scale_bias(ctx, cdf, -1.0f, sctx->p); + ggml_set_name(cdf_scaled, "top_p_cdf_scaled"); + + struct ggml_tensor * mask = ggml_step(ctx, cdf_scaled); + ggml_set_name(mask, "top_p_mask"); + + // Taking the sum of the mask gives us the sum of elements after the threshold + // we are interested in. + struct ggml_tensor * idxf = ggml_sum(ctx, mask); + ggml_set_name(idxf, "top_p_index_f32"); + + // prevent out-of-bounds access + idxf = ggml_clamp(ctx, idxf, 0.0f, mask->ne[0] - 1); + + // construct ones tensor to set the value in the mask + struct ggml_tensor * ones = ggml_scale_bias(ctx, idxf, 0.0f, 1.0f); + ggml_set_name(ones, "top_p_ones"); + + // Make top-p inclusive (i.e. return all values such that cum_sum/cdf >= p) + struct ggml_tensor * mask_reshaped = ggml_reshape_2d(ctx, mask, 1, mask->ne[0]); + + mask_reshaped = ggml_set_rows(ctx, mask_reshaped, ones, ggml_cast(ctx, idxf, GGML_TYPE_I32)); + mask = ggml_reshape_1d(ctx, mask_reshaped, mask->ne[0]); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // top_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning that will not effect the logits). + const float large_val = 1e9f; + struct ggml_tensor * top_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(top_p_bias, "top_p_bias"); + + data->logits = ggml_add(ctx, sorted_logits, top_p_bias); + ggml_set_name(data->logits, "top_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_top_p_i = { - /* .name = */ llama_sampler_top_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_p_clone, - /* .free = */ llama_sampler_top_p_free, + /* .name = */ llama_sampler_top_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_p_clone, + /* .free = */ llama_sampler_top_p_free, + /* .backend_init = */ llama_sampler_top_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_top_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { + const bool is_empty = p >= 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?top-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_p_i, /* .ctx = */ new llama_sampler_top_p { + ("top-p"), /* .p = */ p, /* .min_keep = */ min_keep, /* .buf_sort = */ {}, @@ -884,13 +1560,14 @@ struct llama_sampler * llama_sampler_init_top_p(float p, size_t min_keep) { // min-p -struct llama_sampler_min_p { +struct llama_sampler_min_p : public llama_sampler_backend { const float p; const size_t min_keep; }; -static const char * llama_sampler_min_p_name(const struct llama_sampler * /*smpl*/) { - return "min-p"; +static const char * llama_sampler_min_p_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_min_p_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -956,19 +1633,85 @@ static void llama_sampler_min_p_free(struct llama_sampler * smpl) { delete (llama_sampler_min_p *) smpl->ctx; } +static bool llama_sampler_min_p_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_min_p_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_min_p *) smpl->ctx; + + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "max_idx"); + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + ggml_set_name(logits_rows, "logits_rows"); + + struct ggml_tensor * max_logit = ggml_get_rows(ctx, logits_rows, max_idx); + ggml_set_name(max_logit, "max_logit"); + + // Calculate the threshold value. + struct ggml_tensor * threshold = ggml_scale_bias(ctx, max_logit, 1.0f, logf(sctx->p)); + ggml_set_name(threshold, "min_p_threshold"); + + // Subtract the threshold from logits. + struct ggml_tensor * sub = ggml_sub(ctx, data->logits, threshold); + + // Create a mask where logits below the threshold are 0 (discard), + // and others are 1 (keep). + struct ggml_tensor * mask = ggml_step(ctx, sub); + ggml_set_name(mask, "min_p_mask"); + + // Use ggml_scale_bias (output = (a * s) + b) which in this case becomes: + // min_p_bias = (mask * 1e9f) - 1e9f. + // So entries in the mask that we want to discard will become -1e9f, and + // others will be 0 (meaning that will not effect the logits). + const float large_val = 1e9f; + struct ggml_tensor * min_p_bias = ggml_scale_bias(ctx, mask, large_val, -large_val); + ggml_set_name(min_p_bias, "min_p_bias"); + + // Add the min_p bias to the logits. + data->logits = ggml_add(ctx, data->logits, min_p_bias); + ggml_set_name(data->logits, "min_p_logits"); + + GGML_UNUSED(gf); +} + static struct llama_sampler_i llama_sampler_min_p_i = { - /* .name = */ llama_sampler_min_p_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_min_p_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_min_p_clone, - /* .free = */ llama_sampler_min_p_free, + /* .name = */ llama_sampler_min_p_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_min_p_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_min_p_clone, + /* .free = */ llama_sampler_min_p_free, + /* .backend_init = */ llama_sampler_min_p_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_min_p_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_min_p(float p, size_t min_keep) { + const bool is_empty = (p <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?min-p"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_min_p_i, /* .ctx = */ new llama_sampler_min_p { + ("min-p"), /* .p = */ p, /* .min_keep = */ min_keep, } @@ -1056,15 +1799,25 @@ static void llama_sampler_typical_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_typical_i = { - /* .name = */ llama_sampler_typical_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_typical_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_typical_clone, - /* .free = */ llama_sampler_typical_free, + /* .name = */ llama_sampler_typical_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_typical_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_typical_clone, + /* .free = */ llama_sampler_typical_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { + const bool is_empty = (p >= 1.0f); + + if (is_empty) { + return llama_sampler_init_empty("?typical"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_typical_i, /* .ctx = */ new llama_sampler_typical { @@ -1076,12 +1829,13 @@ struct llama_sampler * llama_sampler_init_typical(float p, size_t min_keep) { // temp -struct llama_sampler_temp { +struct llama_sampler_temp : public llama_sampler_backend { const float temp; }; -static const char * llama_sampler_temp_name(const struct llama_sampler * /*smpl*/) { - return "temp"; +static const char * llama_sampler_temp_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1099,19 +1853,79 @@ static void llama_sampler_temp_free(struct llama_sampler * smpl) { delete (llama_sampler_temp *) smpl->ctx; } +static void llama_sampler_backend_temp_sampling( + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data, + float temp) { + if (temp <= 0.0f) { + // Find the most probable token index. + struct ggml_tensor * max_idx = ggml_argmax(ctx, data->logits); + ggml_set_name(max_idx, "temp_max_idx"); + + if (data->candidates) { + struct ggml_tensor * candidates_rows = ggml_reshape_2d(ctx, data->candidates, 1, data->candidates->ne[0]); + data->candidates = ggml_get_rows(ctx, candidates_rows, max_idx); + } else { + data->candidates = max_idx; + } + + struct ggml_tensor * logits_rows = ggml_reshape_2d(ctx, data->logits, 1, data->logits->ne[0]); + data->logits = ggml_get_rows(ctx, logits_rows, max_idx); + + return; + } + + data->logits = ggml_scale(ctx, data->logits, 1.0f / temp); + + GGML_UNUSED(gf); +} + +static bool llama_sampler_temp_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp *) smpl->ctx; + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); +} + static struct llama_sampler_i llama_sampler_temp_i = { - /* .name = */ llama_sampler_temp_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_clone, - /* .free = */ llama_sampler_temp_free, + /* .name = */ llama_sampler_temp_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_clone, + /* .free = */ llama_sampler_temp_free, + /* .backend_init = */ llama_sampler_temp_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp(float temp) { + const bool is_empty = temp == 1.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_temp_i, /* .ctx = */ new llama_sampler_temp { + ("temp"), /*.temp = */ temp, } ); @@ -1119,14 +1933,15 @@ struct llama_sampler * llama_sampler_init_temp(float temp) { // temp-ext -struct llama_sampler_temp_ext { +struct llama_sampler_temp_ext : public llama_sampler_backend { const float temp; const float delta; const float exponent; }; -static const char * llama_sampler_temp_ext_name(const struct llama_sampler * /*smpl*/) { - return "temp-ext"; +static const char * llama_sampler_temp_ext_name(const struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + return sctx->get_name(); } static void llama_sampler_temp_ext_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -1209,24 +2024,112 @@ static void llama_sampler_temp_ext_free(struct llama_sampler * smpl) { delete (llama_sampler_temp_ext *) smpl->ctx; } +static bool llama_sampler_temp_ext_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + const bool res = llama_sampler_backend_support(smpl, buft); + + sctx->init(res); + + return res; +} + +static void llama_sampler_temp_ext_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + auto * sctx = (llama_sampler_temp_ext *) smpl->ctx; + + // Revert to standard temperature scaling if delta or temp are non-positive. + if (sctx->delta <= 0.0f || sctx->temp <= 0.0f) { + llama_sampler_backend_temp_sampling(ctx, gf, data, sctx->temp); + return; + } + + // Calculate min_temp, max_temp, and max_entropy. + const float min_temp = std::max(0.0f, sctx->temp - sctx->delta); + const float max_temp = sctx->temp + sctx->delta; + const float max_entropy = logf(data->logits->ne[0]); + + // Calculate the probabilities. + struct ggml_tensor * probs = ggml_soft_max(ctx, data->logits); + ggml_set_name(probs, "temp_ext_softmax_probs"); + + // Clamp probabilities to avoid log(0) which would give -inf + struct ggml_tensor * probs_clamped = ggml_clamp(ctx, probs, 1e-10f, 1.0f); + ggml_set_name(probs_clamped, "temp_ext_probs_clamped"); + + // Calculate the entropy, entropy = -Σ(p * log(p)). + struct ggml_tensor * log_probs = ggml_log(ctx, probs_clamped); + struct ggml_tensor * p_log_p = ggml_mul(ctx, probs_clamped, log_probs); + struct ggml_tensor * sum_p_log_p = ggml_sum(ctx, p_log_p); + struct ggml_tensor * entropy = ggml_scale(ctx, sum_p_log_p, -1.0f); + ggml_set_name(log_probs, "temp_ext_log_probs"); + ggml_set_name(p_log_p, "temp_ext_p_log_p"); + ggml_set_name(sum_p_log_p, "temp_ext_sum_p_log_p"); + ggml_set_name(entropy, "temp_ext_entropy"); + + // Normalize the entropy, norm_entropy = entropy / max_entropy + struct ggml_tensor * norm_entropy = ggml_scale(ctx, entropy, 1.0f / max_entropy); + ggml_set_name(norm_entropy, "temp_ext_norm_entropy"); + + // Calculate the dynamic temperature: + // dyn_temp = min_temp + (max_temp - min_temp) * powf(normalized_entropy, exponent); + // + // Calculate powf(normalized_entropy, exponent) as + // norm_entropy^exponent = exp(exponent * log(norm_entropy)) + struct ggml_tensor * log_norm_entropy = ggml_log(ctx, norm_entropy); + struct ggml_tensor * scaled_log = ggml_scale(ctx, log_norm_entropy, sctx->exponent); + struct ggml_tensor * pow_entropy = ggml_exp(ctx, scaled_log); + // With pow_entropy computed we can now compute dyn_temp, scaling by + // (max_temp - min_temp) and then adding min_temp. + struct ggml_tensor * dyn_temp = ggml_scale_bias(ctx, pow_entropy, max_temp - min_temp, min_temp); + ggml_set_name(log_norm_entropy, "temp_ext_log_norm_entropy"); + ggml_set_name(scaled_log, "temp_ext_scaled_log"); + ggml_set_name(pow_entropy, "temp_ext_pow_entropy"); + ggml_set_name(dyn_temp, "temp_ext_dyn_temp"); + + // Scale the logits by the dynamic temperature + struct ggml_tensor * scaled_logits = ggml_div(ctx, data->logits, dyn_temp); + ggml_set_name(scaled_logits, "temp_ext_scaled_logits"); + + data->logits = scaled_logits; +} + static struct llama_sampler_i llama_sampler_temp_ext_i = { - /* .name = */ llama_sampler_temp_ext_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_temp_ext_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_temp_ext_clone, - /* .free = */ llama_sampler_temp_ext_free, + /* .name = */ llama_sampler_temp_ext_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_temp_ext_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_temp_ext_clone, + /* .free = */ llama_sampler_temp_ext_free, + /* .backend_init = */ llama_sampler_temp_ext_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_temp_ext_backend_apply, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_temp_ext(float temp, float delta, float exponent) { - return llama_sampler_init( + const bool is_empty = temp == 1.0f && delta <= 0.0f; + + if (is_empty) { + return llama_sampler_init_empty("?temp-ext"); + } + + auto * res = llama_sampler_init( /* .iface = */ &llama_sampler_temp_ext_i, /* .ctx = */ new llama_sampler_temp_ext { + ("temp-ext"), /* .temp = */ temp, /* .delta = */ delta, /* .exponent = */ exponent, } ); + + return res; } // xtc @@ -1304,16 +2207,27 @@ static void llama_sampler_xtc_reset(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_xtc_i = { - /* .name = */ llama_sampler_xtc_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sample_xtc_apply, - /* .reset = */ llama_sampler_xtc_reset, - /* .clone = */ llama_sampler_xtc_clone, - /* .free = */ llama_sampler_xtc_free, + /* .name = */ llama_sampler_xtc_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sample_xtc_apply, + /* .reset = */ llama_sampler_xtc_reset, + /* .clone = */ llama_sampler_xtc_clone, + /* .free = */ llama_sampler_xtc_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_xtc(float p, float t, size_t min_keep, uint32_t seed) { - auto seed_cur = get_rng_seed(seed); + const bool is_empty = (p <= 0.0f || t > 0.5f); + + if (is_empty) { + return llama_sampler_init_empty("?xtc"); + } + + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_xtc_i, /* .ctx = */ new llama_sampler_xtc { @@ -1412,16 +2326,21 @@ static void llama_sampler_mirostat_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_i = { - /* .name = */ llama_sampler_mirostat_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_apply, - /* .reset = */ llama_sampler_mirostat_reset, - /* .clone = */ llama_sampler_mirostat_clone, - /* .free = */ llama_sampler_mirostat_free, + /* .name = */ llama_sampler_mirostat_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_apply, + /* .reset = */ llama_sampler_mirostat_reset, + /* .clone = */ llama_sampler_mirostat_clone, + /* .free = */ llama_sampler_mirostat_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat(int32_t n_vocab, uint32_t seed, float tau, float eta, int32_t m) { - auto seed_cur = get_rng_seed(seed); + const auto seed_cur = get_rng_seed(seed); + return llama_sampler_init( /* .iface = */ &llama_sampler_mirostat_i, /* .ctx = */ new llama_sampler_mirostat { @@ -1511,12 +2430,16 @@ static void llama_sampler_mirostat_v2_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_mirostat_v2_i = { - /* .name = */ llama_sampler_mirostat_v2_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_mirostat_v2_apply, - /* .reset = */ llama_sampler_mirostat_v2_reset, - /* .clone = */ llama_sampler_mirostat_v2_clone, - /* .free = */ llama_sampler_mirostat_v2_free, + /* .name = */ llama_sampler_mirostat_v2_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_mirostat_v2_apply, + /* .reset = */ llama_sampler_mirostat_v2_reset, + /* .clone = */ llama_sampler_mirostat_v2_clone, + /* .free = */ llama_sampler_mirostat_v2_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_mirostat_v2(uint32_t seed, float tau, float eta) { @@ -1628,12 +2551,16 @@ static void llama_sampler_grammar_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_grammar_i = { - /* .name = */ llama_sampler_grammar_name, - /* .accept = */ llama_sampler_grammar_accept_impl, - /* .apply = */ llama_sampler_grammar_apply, - /* .reset = */ llama_sampler_grammar_reset, - /* .clone = */ llama_sampler_grammar_clone, - /* .free = */ llama_sampler_grammar_free, + /* .name = */ llama_sampler_grammar_name, + /* .accept = */ llama_sampler_grammar_accept_impl, + /* .apply = */ llama_sampler_grammar_apply, + /* .reset = */ llama_sampler_grammar_reset, + /* .clone = */ llama_sampler_grammar_clone, + /* .free = */ llama_sampler_grammar_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; static struct llama_sampler * llama_sampler_init_grammar_impl( @@ -1835,12 +2762,16 @@ static void llama_sampler_penalties_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_penalties_i = { - /* .name = */ llama_sampler_penalties_name, - /* .accept = */ llama_sampler_penalties_accept, - /* .apply = */ llama_sampler_penalties_apply, - /* .reset = */ llama_sampler_penalties_reset, - /* .clone = */ llama_sampler_penalties_clone, - /* .free = */ llama_sampler_penalties_free, + /* .name = */ llama_sampler_penalties_name, + /* .accept = */ llama_sampler_penalties_accept, + /* .apply = */ llama_sampler_penalties_apply, + /* .reset = */ llama_sampler_penalties_reset, + /* .clone = */ llama_sampler_penalties_clone, + /* .free = */ llama_sampler_penalties_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_penalties( @@ -1850,6 +2781,12 @@ struct llama_sampler * llama_sampler_init_penalties( float penalty_present) { penalty_last_n = std::max(penalty_last_n, 0); + const bool is_empty = (penalty_last_n == 0 || (penalty_repeat == 1.0f && penalty_freq == 0.0f && penalty_present == 0.0f)); + + if (is_empty) { + return llama_sampler_init_empty("?penalties"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_penalties_i, /* .ctx = */ new llama_sampler_penalties { @@ -1887,9 +2824,7 @@ static void llama_sampler_top_n_sigma_apply(struct llama_sampler * smpl, llama_t for (size_t i = 0; i < cur_p->size; ++i) { // Only count non-negative infinity values if (cur_p->data[i].logit != -INFINITY) { - if (cur_p->data[i].logit > max) { - max = cur_p->data[i].logit; - } + max = std::max(max, cur_p->data[i].logit); logits_sum += cur_p->data[i].logit; valid_count++; } @@ -1926,15 +2861,25 @@ static void llama_sampler_top_n_sigma_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_top_n_sigma_i = { - /* .name = */ llama_sampler_top_n_sigma_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_top_n_sigma_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_top_n_sigma_clone, - /* .free = */ llama_sampler_top_n_sigma_free, + /* .name = */ llama_sampler_top_n_sigma_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_top_n_sigma_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_top_n_sigma_clone, + /* .free = */ llama_sampler_top_n_sigma_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_top_n_sigma(float n) { + const bool is_empty = (n <= 0.0f); + + if (is_empty) { + return llama_sampler_init_empty("?top-n-sigma"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_top_n_sigma_i, /* .ctx = */ new llama_sampler_top_n_sigma { @@ -2256,12 +3201,16 @@ static void llama_sampler_dry_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_dry_i = { - /* .name = */ llama_sampler_dry_name, - /* .accept = */ llama_sampler_dry_accept, - /* .apply = */ llama_sampler_dry_apply, - /* .reset = */ llama_sampler_dry_reset, - /* .clone = */ llama_sampler_dry_clone, - /* .free = */ llama_sampler_dry_free, + /* .name = */ llama_sampler_dry_name, + /* .accept = */ llama_sampler_dry_accept, + /* .apply = */ llama_sampler_dry_apply, + /* .reset = */ llama_sampler_dry_reset, + /* .clone = */ llama_sampler_dry_clone, + /* .free = */ llama_sampler_dry_free, + /* .backend_init = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ nullptr, + /* .backend_set_input = */ nullptr, }; struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, int32_t n_ctx_train, float dry_multiplier, float dry_base, int32_t dry_allowed_length, int32_t dry_penalty_last_n, const char** seq_breakers, size_t num_breakers) { @@ -2272,6 +3221,10 @@ struct llama_sampler * llama_sampler_init_dry(const struct llama_vocab * vocab, const bool dry_enabled = (dry_multiplier != 0.0f && dry_base >= 1.0f && dry_penalty_last_n != 0); + if (!dry_enabled) { + return llama_sampler_init_empty("?dry"); + } + if (dry_enabled && seq_breakers != nullptr && num_breakers > 0) { // Process sequence breakers for (size_t i = 0; i < num_breakers; ++i) { @@ -2342,16 +3295,23 @@ struct llama_sampler * llama_sampler_init_dry_testing(int32_t context_size, floa // logit-bias -struct llama_sampler_logit_bias { +struct llama_sampler_logit_bias : public llama_sampler_backend { const int32_t n_vocab; const std::vector logit_bias; std::vector to_search; + + struct ggml_tensor * inp_logit_bias; + struct ggml_tensor * inp_logit_idxs; + + ggml_context_ptr inp_ctx; + ggml_backend_buffer_ptr inp_buf; }; -static const char * llama_sampler_logit_bias_name(const struct llama_sampler * /*smpl*/) { - return "logit-bias"; +static const char * llama_sampler_logit_bias_name(const struct llama_sampler * smpl) { + auto * ctx = (llama_sampler_logit_bias *) smpl->ctx; + return ctx->get_name(); } static void llama_sampler_logit_bias_apply(struct llama_sampler * smpl, llama_token_data_array * cur_p) { @@ -2396,25 +3356,123 @@ static void llama_sampler_logit_bias_free(struct llama_sampler * smpl) { delete (llama_sampler_logit_bias *) smpl->ctx; } +static void llama_sampler_logit_bias_backend_apply( + struct llama_sampler * smpl, + struct ggml_context * ctx, + struct ggml_cgraph * gf, + struct llama_sampler_data * data) { + GGML_UNUSED(gf); + GGML_UNUSED(ctx); + + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + ggml_tensor * cur = ggml_fill(ctx, data->logits, 0.0f); + + cur = ggml_reshape_2d(ctx, cur, 1, ggml_nelements(cur)); + cur = ggml_set_rows(ctx, cur, sctx->inp_logit_bias, sctx->inp_logit_idxs); + cur = ggml_reshape_1d(ctx, cur, ggml_nelements(cur)); + + data->logits = ggml_add(ctx, data->logits, cur); +} + +static void llama_sampler_logit_bias_backend_set_input(struct llama_sampler * smpl) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + if (sctx->logit_bias.empty()) { + return; + } + + GGML_ASSERT(sctx->inp_logit_bias != nullptr); + GGML_ASSERT(sctx->inp_logit_idxs != nullptr); + + const size_t n = sctx->logit_bias.size(); + + std::vector data_logit_bias(n, 0.0f); + std::vector data_logit_idxs(n, 0); + for (size_t i = 0; i < n; ++i) { + const auto & lb = sctx->logit_bias[i]; + GGML_ASSERT(lb.token >= 0 && lb.token < (int32_t) sctx->n_vocab); + data_logit_bias[i] = lb.bias; + data_logit_idxs[i] = lb.token; + } + + ggml_backend_tensor_set(sctx->inp_logit_bias, data_logit_bias.data(), 0, ggml_nbytes(sctx->inp_logit_bias)); + ggml_backend_tensor_set(sctx->inp_logit_idxs, data_logit_idxs.data(), 0, ggml_nbytes(sctx->inp_logit_idxs)); +} + +static bool llama_sampler_logit_bias_backend_init( + struct llama_sampler * smpl, + ggml_backend_buffer_type_t buft) { + auto * sctx = (llama_sampler_logit_bias *) smpl->ctx; + + sctx->init(true); + + if (sctx->logit_bias.empty()) { + return true; + } + + ggml_init_params params = { + /*.mem_size =*/ 2*ggml_tensor_overhead(), + /*.mem_buffer =*/ nullptr, + /*.no_alloc =*/ true, + }; + + sctx->inp_ctx.reset(ggml_init(params)); + + const size_t n = sctx->logit_bias.size(); + + sctx->inp_logit_bias = ggml_new_tensor_2d(sctx->inp_ctx.get(), GGML_TYPE_F32, 1, n); + ggml_set_name(sctx->inp_logit_bias, "logit_bias"); + ggml_set_input(sctx->inp_logit_bias); + + sctx->inp_logit_idxs = ggml_new_tensor_1d(sctx->inp_ctx.get(), GGML_TYPE_I32, n); + ggml_set_name(sctx->inp_logit_idxs, "logit_idxs"); + ggml_set_input(sctx->inp_logit_idxs); + + // Allocate all tensors from our context to the backend + sctx->inp_buf.reset(ggml_backend_alloc_ctx_tensors_from_buft(sctx->inp_ctx.get(), buft)); + + ggml_backend_buffer_clear(sctx->inp_buf.get(), 0); + + return true; +} + static struct llama_sampler_i llama_sampler_logit_bias_i = { - /* .name = */ llama_sampler_logit_bias_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_logit_bias_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_logit_bias_clone, - /* .free = */ llama_sampler_logit_bias_free, + /* .name = */ llama_sampler_logit_bias_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_logit_bias_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_logit_bias_clone, + /* .free = */ llama_sampler_logit_bias_free, + /* .backend_init = */ llama_sampler_logit_bias_backend_init, + /* .backend_accept = */ nullptr, + /* .backend_apply = */ llama_sampler_logit_bias_backend_apply, + /* .backend_set_input = */ llama_sampler_logit_bias_backend_set_input, }; struct llama_sampler * llama_sampler_init_logit_bias( int32_t n_vocab, int32_t n_logit_bias, const llama_logit_bias * logit_bias) { + const bool is_empty = n_logit_bias <= 0; + + if (is_empty) { + return llama_sampler_init_empty("?logit-bias"); + } + return llama_sampler_init( /* .iface = */ &llama_sampler_logit_bias_i, /* .ctx = */ new llama_sampler_logit_bias { - /* .n_vocab = */ n_vocab, - /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), - /* .to_search = */ {}, + ("logit-bias"), + /* .n_vocab = */ n_vocab, + /* .logit_bias = */ std::vector(logit_bias, logit_bias + n_logit_bias), + /* .to_search = */ {}, + /* .inp_logit_bias = */ nullptr, + /* .inp_logit_idxs = */ nullptr, + /* .inp_ctx = */ nullptr, + /* .inp_buf = */ nullptr, } ); } @@ -2627,12 +3685,16 @@ static void llama_sampler_infill_free(struct llama_sampler * smpl) { } static struct llama_sampler_i llama_sampler_infill_i = { - /* .name = */ llama_sampler_infill_name, - /* .accept = */ nullptr, - /* .apply = */ llama_sampler_infill_apply, - /* .reset = */ nullptr, - /* .clone = */ llama_sampler_infill_clone, - /* .free = */ llama_sampler_infill_free, + /* .name = */ llama_sampler_infill_name, + /* .accept = */ nullptr, + /* .apply = */ llama_sampler_infill_apply, + /* .reset = */ nullptr, + /* .clone = */ llama_sampler_infill_clone, + /* .free = */ llama_sampler_infill_free, + /* .backend_apply = */ nullptr, + /* .backend_accept = */ nullptr, + /* .backend_set_input = */ nullptr, + /* .backend_init = */ nullptr, }; struct llama_sampler * llama_sampler_init_infill(const struct llama_vocab * vocab) { @@ -2664,7 +3726,7 @@ uint32_t llama_sampler_get_seed(const struct llama_sampler * smpl) { if (smpl->iface == &llama_sampler_chain_i) { const auto * ctx = (const llama_sampler_chain *) smpl->ctx; for (auto it = ctx->samplers.rbegin(); it != ctx->samplers.rend(); ++it) { - const uint32_t seed = llama_sampler_get_seed(*it); + const uint32_t seed = llama_sampler_get_seed(it->ptr); if (seed != LLAMA_DEFAULT_SEED) { return seed; } diff --git a/src/llama-sampling.h b/src/llama-sampling.h index 1e3de4e2ec..6a963c0bb7 100644 --- a/src/llama-sampling.h +++ b/src/llama-sampling.h @@ -14,7 +14,16 @@ struct llama_grammar; struct llama_sampler_chain { llama_sampler_chain_params params; - std::vector samplers; + // has .backend_init() been called? + bool is_init = false; + + struct info { + bool is_backend; + + llama_sampler * ptr; + }; + + std::vector samplers; // pre-allocated buffer for llama_sampler_sample to avoid repeated allocations std::vector cur; @@ -27,9 +36,9 @@ struct llama_sampler_chain { }; struct llama_sampler * llama_sampler_init_dry_testing( - int32_t context_size, - float dry_multiplier, - float dry_base, - int32_t dry_allowed_length, - int32_t dry_penalty_last_n, - const std::vector>& seq_breakers); + int32_t context_size, + float dry_multiplier, + float dry_base, + int32_t dry_allowed_length, + int32_t dry_penalty_last_n, + const std::vector> & seq_breakers); diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp index cd4092ca07..a20c6525e4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -314,6 +314,12 @@ struct llm_tokenizer_bpe : llm_tokenizer { "[!\"#$%&'()*+,\\-./:;<=>?@\\[\\\\\\]^_`{|}~][A-Za-z]+|[^\r\n\\p{L}\\p{P}\\p{S}]?[\\p{L}\\p{M}]+| ?[\\p{P}\\p{S}]+[\r\n]*|\\s*[\r\n]+|\\s+(?!\\S)|\\s+", }; break; + case LLAMA_VOCAB_PRE_TYPE_YOUTU: + regex_exprs = { + "[가-힣ㄱ-ㆎ]+|[!…“”‘’—:;,、-〿︰-﹏]+|[ㄅ-ㄯ]+|[一-龥぀-ゟ゠-ヿ]+", + "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*(?:'[sS]|'[tT]|'[rR][eE]|'[vV][eE]|'[mM]|'[lL][lL]|'[dD])?|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+", + }; + break; case LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER: regex_exprs = { "[\r\n]", @@ -355,6 +361,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_STABLELM2: case LLAMA_VOCAB_PRE_TYPE_QWEN2: case LLAMA_VOCAB_PRE_TYPE_HUNYUAN: + case LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN: regex_exprs = { // original regex from tokenizer.json // "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+" @@ -1860,6 +1867,11 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "deepseek-v3") { pre_type = LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM; clean_spaces = false; + } else if ( + tokenizer_pre == "youtu") { + pre_type = LLAMA_VOCAB_PRE_TYPE_YOUTU; + clean_spaces = false; + ignore_merges = true; } else if ( tokenizer_pre == "falcon") { pre_type = LLAMA_VOCAB_PRE_TYPE_FALCON; @@ -2015,6 +2027,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "minimax-m2") { pre_type = LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2; clean_spaces = false; + } else if ( + tokenizer_pre == "solar-open") { + pre_type = LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN; + clean_spaces = false; } else { throw std::runtime_error(format("unknown pre-tokenizer type: '%s'", tokenizer_pre.c_str())); } @@ -2187,6 +2203,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { // for now, we apply this workaround to find the tokens based on their text for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + // find EOT token: "<|eot_id|>", "<|im_end|>", "", etc. if (special_eot_id == LLAMA_TOKEN_NULL) { if (false @@ -2202,10 +2220,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eot_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2216,10 +2234,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|eom_id|>" ) { special_eom_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2236,10 +2254,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_prefix|>" // GLM-4.5 ) { special_fim_pre_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2256,10 +2274,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_suffix|>" // GLM-4.5 ) { special_fim_suf_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2276,10 +2294,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|code_middle|>" // GLM-4.5 ) { special_fim_mid_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2293,10 +2311,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" ) { special_fim_pad_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2311,10 +2329,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // Granite ) { special_fim_rep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } @@ -2325,15 +2343,41 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|file_sep|>" // Qwen ) { special_fim_sep_id = t.second; - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } } } + // auto-detect unused tokens: e.g. control tokens with the word "unused" + // ideally, these tokens should be marked as unused during conversion + { + uint32_t n_unused = 0; + + for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + continue; + } + + if ((attr & LLAMA_TOKEN_ATTR_UNUSED) == 0) { + if (strstr(t.first.c_str(), "unused") != NULL) { + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_UNUSED); + } + } + + if (attr & LLAMA_TOKEN_ATTR_UNUSED) { + n_unused++; + } + } + + LLAMA_LOG_INFO("%s: %u unused tokens\n", __func__, n_unused); + } + // maintain a list of tokens that cause end-of-generation // this is currently determined based on the token text, which is obviously not ideal // ref: https://github.com/ggerganov/llama.cpp/issues/9606 @@ -2352,12 +2396,16 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { } for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (false || t.first == "<|eot_id|>" || t.first == "<|im_end|>" || t.first == "<|end|>" || t.first == "<|return|>" // o200k_harmony || t.first == "<|call|>" // o200k_harmony + || t.first == "<|flush|>" // solar-open + || t.first == "<|calls|>" // solar-open || t.first == "" || t.first == "<|endoftext|>" || t.first == "<|eom_id|>" @@ -2367,24 +2415,28 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "" // smoldocling ) { special_eog_ids.insert(t.second); - if ((id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { + if ((attr & LLAMA_TOKEN_ATTR_CONTROL) == 0) { LLAMA_LOG_WARN("%s: control-looking token: %6d '%s' was not control-type; this is probably a bug in the model. its type will be overridden\n", __func__, t.second, t.first.c_str()); - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_CONTROL; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_CONTROL); } } else { - // token is control, but not marked as EOG -> print a debug log - if (id_to_token[t.second].attr & LLAMA_TOKEN_ATTR_CONTROL && special_eog_ids.count(t.second) == 0) { - LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", - __func__, t.second, t.first.c_str()); + if (attr & LLAMA_TOKEN_ATTR_CONTROL && !(attr & LLAMA_TOKEN_ATTR_UNUSED)) { + // token is control, but not marked as EOG -> print a debug log + if (special_eog_ids.count(t.second) == 0) { + LLAMA_LOG_DEBUG("%s: control token: %6d '%s' is not marked as EOG\n", + __func__, t.second, t.first.c_str()); + } } } } // @ngxson : quick hack for gpt-oss, always render these tokens for (const auto & t : token_to_id) { + auto & attr = id_to_token[t.second].attr; + if (t.first == "<|channel|>" || t.first == "<|message|>" || t.first == "<|start|>" || t.first == "<|constrain|>") { - id_to_token[t.second].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); } } @@ -2404,34 +2456,42 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { LLAMA_LOG_WARN("%s: special_eom_id is not in special_eog_ids - the tokenizer config may be incorrect\n", __func__); } - // TODO: workaround for o200k_harmony tokenizer: the "<|end|>" token should not be EOG - // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens, + // TODO: workaround for o200k_harmony and solar-open tokenizer: the "<|end|>" token should not be EOG + // we don't have a good way to detect this, so for now, if we have "<|return|>" and "<|call|>" tokens ("<|calls|>" and "<|flush|>" for solar-open), // we remove the "<|end|>" token from the EOG list { bool has_return = false; bool has_call = false; bool has_end = false; + bool has_flush = false; llama_token end_id = LLAMA_TOKEN_NULL; LLAMA_LOG_INFO("%s: printing all EOG tokens:\n", __func__); for (auto tid : special_eog_ids) { - LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, id_to_token[tid].text.c_str()); + auto & text = id_to_token[tid].text; - if (id_to_token[tid].text == "<|return|>") { + LLAMA_LOG_INFO("%s: - %d ('%s')\n", __func__, tid, text.c_str()); + + if (text == "<|return|>") { has_return = true; - } else if (id_to_token[tid].text == "<|call|>") { + } else if (text == "<|call|>" || text == "<|calls|>") { has_call = true; - } else if (id_to_token[tid].text == "<|end|>") { + } else if (text == "<|flush|>") { + has_flush = true; + } else if (text == "<|end|>") { has_end = true; end_id = tid; } } - if (has_return && has_call && has_end) { + if ((has_return && has_call && has_end) || (has_call && has_flush && has_end)) { special_eog_ids.erase(end_id); - id_to_token[end_id].attr = LLAMA_TOKEN_ATTR_USER_DEFINED; - LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>' tokens, removing '<|end|>' token from EOG list\n", __func__); + + auto & attr = id_to_token[end_id].attr; + attr = (llama_token_attr) (attr | LLAMA_TOKEN_ATTR_USER_DEFINED); + + LLAMA_LOG_WARN("%s: special_eog_ids contains both '<|return|>' and '<|call|>', or '<|calls|>' and '<|flush|>' tokens, removing '<|end|>' token from EOG list\n", __func__); } } } diff --git a/src/llama-vocab.h b/src/llama-vocab.h index 55f8f3923c..2b240a5491 100644 --- a/src/llama-vocab.h +++ b/src/llama-vocab.h @@ -51,6 +51,8 @@ enum llama_vocab_pre_type { LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, LLAMA_VOCAB_PRE_TYPE_MINIMAX_M2 = 41, LLAMA_VOCAB_PRE_TYPE_AFMOE = 42, + LLAMA_VOCAB_PRE_TYPE_SOLAR_OPEN = 43, + LLAMA_VOCAB_PRE_TYPE_YOUTU = 44, }; struct LLM_KV; diff --git a/src/llama.cpp b/src/llama.cpp index 76b3acbadb..0162ae8d58 100644 --- a/src/llama.cpp +++ b/src/llama.cpp @@ -359,6 +359,11 @@ static void llama_params_fit_impl( // for the first partial layer varying parts can overflow, all further layers use LAYER_FRACTION_MOE: layer_fraction_t overflow_type = LAYER_FRACTION_MOE; + + uint32_t n_full() const { + assert(n_layer >= n_part); + return n_layer - n_part; + } }; const size_t ntbo = llama_max_tensor_buft_overrides(); @@ -382,7 +387,7 @@ static void llama_params_fit_impl( size_t itbo = 0; for (size_t id = 0; id < nd; id++) { - il0 += ngl_per_device[id].n_layer - ngl_per_device[id].n_part; + il0 += ngl_per_device[id].n_full(); for (uint32_t il = il0; il < il0 + ngl_per_device[id].n_part; il++) { if (itbo + 1 >= ntbo) { tensor_buft_overrides[itbo].pattern = nullptr; @@ -393,7 +398,7 @@ static void llama_params_fit_impl( + std::to_string(ntbo) + " is insufficient for model"); } tensor_buft_overrides[itbo].pattern = get_overflow_pattern(il, il == il0 ? ngl_per_device[id].overflow_type : LAYER_FRACTION_MOE); - tensor_buft_overrides[itbo].buft = overflow_bufts[id]; + tensor_buft_overrides[itbo].buft = il == il0 ? overflow_bufts[id] : ggml_backend_cpu_buffer_type(); itbo++; } il0 += ngl_per_device[id].n_part; @@ -468,20 +473,14 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: id=%zu, target=%" PRId64 " MiB\n", __func__, id, targets[id]/MiB); } - std::vector overflow_bufts; // which bufts the partial layers of a device overflow to: + std::vector overflow_bufts; // which bufts the first partial layer of a device overflows to: overflow_bufts.reserve(nd); - for (size_t id = 0; id < nd - 1; ++id) { - overflow_bufts.push_back(ggml_backend_dev_buffer_type(devs[id + 1])); + for (size_t id = 0; id < nd; id++) { + overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); } - overflow_bufts.push_back(ggml_backend_cpu_buffer_type()); std::vector ngl_per_device(nd); std::vector mem = get_memory_for_layers(__func__, ngl_per_device, overflow_bufts); - if (hp_nex > 0) { - for (size_t id = 0; id < nd; id++) { - ngl_per_device[id].overflow_type = LAYER_FRACTION_MOE; - } - } // optimize the number of layers per device using the method of false position: // - ngl_per_device has 0 layers for each device, lower bound @@ -512,9 +511,6 @@ static void llama_params_fit_impl( if (mem_high[id] > targets[id]) { assert(ngl_per_device_high[id].n_layer > ngl_per_device[id].n_layer); uint32_t delta = ngl_per_device_high[id].n_layer - ngl_per_device[id].n_layer; - if (hp_nex > 0 && size_t(id) == nd - 1) { - delta--; - } LLAMA_LOG_DEBUG("%s: start filling device %" PRIu32 ", delta=%" PRIu32 "\n", __func__, id, delta); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); @@ -524,7 +520,8 @@ static void llama_params_fit_impl( std::vector ngl_per_device_test = ngl_per_device; ngl_per_device_test[id].n_layer += step_size; if (hp_nex) { - ngl_per_device_test[id].n_part += step_size; + ngl_per_device_test[id].n_part += size_t(id) == nd - 1 && ngl_per_device_test[id].n_part == 0 ? + step_size - 1 : step_size; // the first layer is the output layer which must always be full } const std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); @@ -573,7 +570,7 @@ static void llama_params_fit_impl( assert(id_dense_start < nd); LLAMA_LOG_INFO("%s: converting dense-only layers to full layers and filling them front-to-back with overflow to next device/system memory:\n", __func__); - for (size_t id = 0; id <= id_dense_start; id++) { + for (size_t id = 0; id <= id_dense_start && id_dense_start < nd; id++) { std::vector ngl_per_device_high = ngl_per_device; for (size_t jd = id_dense_start; jd < nd; jd++) { const uint32_t n_layer_move = jd < nd - 1 ? ngl_per_device_high[jd].n_layer : ngl_per_device_high[jd].n_layer - 1; @@ -585,12 +582,8 @@ static void llama_params_fit_impl( std::vector mem_high = get_memory_for_layers(__func__, ngl_per_device_high, overflow_bufts); if (mem_high[id] > targets[id]) { - assert(ngl_per_device_high[id].n_layer >= ngl_per_device_high[id].n_part); - assert(ngl_per_device[id].n_layer >= ngl_per_device[id].n_part); - assert((ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - >= ngl_per_device[id].n_layer - ngl_per_device[id].n_part); - uint32_t delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + uint32_t delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); while (delta > 1) { uint32_t step_size = int64_t(delta) * (targets[id] - mem[id]) / (mem_high[id] - mem[id]); step_size = std::max(step_size, uint32_t(1)); @@ -606,7 +599,7 @@ static void llama_params_fit_impl( ngl_per_device_test[id].n_layer += n_convert_jd; n_converted_test += n_convert_jd; - if (ngl_per_device_test[id_dense_start_test].n_layer > 0) { + if (ngl_per_device_test[id_dense_start_test].n_part > 0) { break; } } @@ -625,8 +618,8 @@ static void llama_params_fit_impl( LLAMA_LOG_DEBUG("%s: set ngl_per_device_high[%zu].(n_layer, n_part)=(%" PRIu32 ", %" PRIu32 "), id_dense_start_high=%zu\n", __func__, id, ngl_per_device_high[id].n_layer, ngl_per_device_high[id].n_part, id_dense_start_high); } - delta = (ngl_per_device_high[id].n_layer - ngl_per_device_high[id].n_part) - - (ngl_per_device[id].n_layer - ngl_per_device[id].n_part); + assert(ngl_per_device_high[id].n_full() >= ngl_per_device[id].n_full()); + delta = ngl_per_device_high[id].n_full() - ngl_per_device[id].n_full(); } } else { ngl_per_device = ngl_per_device_high; @@ -644,14 +637,19 @@ static void llama_params_fit_impl( ngl_per_device_test[id_dense_start_test].n_part--; ngl_per_device_test[id].n_layer++; ngl_per_device_test[id].n_part++; - if (ngl_per_device_test[id_dense_start_test].n_layer == 0) { + if (ngl_per_device_test[id_dense_start_test].n_part == 0) { id_dense_start_test++; } ngl_per_device_test[id].overflow_type = LAYER_FRACTION_UP; + std::vector overflow_bufts_test = overflow_bufts; + if (id < nd - 1) { + overflow_bufts_test[id] = ggml_backend_dev_buffer_type(devs[id + 1]); + } LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_UP\n", __func__); - std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + std::vector mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", UP), id_dense_start=%zu\n", @@ -659,9 +657,10 @@ static void llama_params_fit_impl( ngl_per_device_test[id].overflow_type = LAYER_FRACTION_GATE; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_GATE\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", GATE), id_dense_start=%zu\n", @@ -670,9 +669,10 @@ static void llama_params_fit_impl( } else { ngl_per_device_test[id].overflow_type = LAYER_FRACTION_ATTN; LLAMA_LOG_DEBUG("%s: trying to fit one extra layer with overflow_type=LAYER_FRACTION_ATTN\n", __func__); - mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts); + mem_test = get_memory_for_layers(__func__, ngl_per_device_test, overflow_bufts_test); if (mem_test[id] < targets[id] && (id + 1 == nd || mem_test[id + 1] < targets[id + 1])) { ngl_per_device = ngl_per_device_test; + overflow_bufts = overflow_bufts_test; mem = mem_test; id_dense_start = id_dense_start_test; LLAMA_LOG_DEBUG("%s: set ngl_per_device[%zu].(n_layer, n_part, overflow_type)=(%" PRIu32 ", %" PRIu32 ", ATTN), id_dense_start=%zu\n", @@ -687,6 +687,14 @@ static void llama_params_fit_impl( __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); } + // print info for devices that were not changed during the conversion from dense only to full layers: + for (size_t id = id_dense_start + 1; id < nd; id++) { + const int64_t projected_margin = dmds_full[id].free - mem[id]; + LLAMA_LOG_INFO( + "%s: - %s: %2" PRIu32 " layers (%2" PRIu32 " overflowing), %6" PRId64 " MiB used, %6" PRId64 " MiB free\n", + __func__, dev_names[id].c_str(), ngl_per_device[id].n_layer, ngl_per_device[id].n_part, mem[id]/MiB, projected_margin/MiB); + } + set_ngl_tensor_split_tbo(ngl_per_device, overflow_bufts, *mparams); } @@ -713,7 +721,7 @@ enum llama_params_fit_status llama_params_fit( struct llama_sampler_chain_params llama_sampler_chain_default_params() { struct llama_sampler_chain_params result = { - /*.no_perf =*/ true, + /*.no_perf =*/ true, }; return result; diff --git a/src/models/afmoe.cpp b/src/models/afmoe.cpp index 0192e344ca..6a752a403f 100644 --- a/src/models/afmoe.cpp +++ b/src/models/afmoe.cpp @@ -22,8 +22,15 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para const float kq_scale = 1.0f/sqrtf(float(n_embd_head)); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous + const bool use_rope = hparams.n_no_rope_layer_step > 0 && + (il + 1) % hparams.n_no_rope_layer_step != 0; + // dual attention normalization (pre) cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -56,19 +63,16 @@ llm_build_afmoe::llm_build_afmoe(const llama_model & model, const llm_graph_para cb(Qcur, "Qcur_normed", il); cb(Kcur, "Kcur_normed", il); - // RoPE only for sliding_attention layers - const bool use_rope = hparams.n_no_rope_layer_step > 0 && - ((il + 1) % hparams.n_no_rope_layer_step) != 0; if (use_rope) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur_rope", il); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Kcur, "Kcur_rope", il); } diff --git a/src/models/bert.cpp b/src/models/bert.cpp index 3274fa3b99..bca0e254fc 100644 --- a/src/models/bert.cpp +++ b/src/models/bert.cpp @@ -142,11 +142,13 @@ llm_build_bert::llm_build_bert(const llama_model & model, const llm_graph_params LLM_FFN_GELU, LLM_FFN_SEQ, il); cb(cur, "ffn_out", il); } else if (model.arch == LLM_ARCH_JINA_BERT_V2) { + const bool up_contains_gate = !model.layers[il].ffn_gate && model.layers[il].ffn_up->ne[1] != hparams.n_ff(); + auto type_op = up_contains_gate ? LLM_FFN_GEGLU : LLM_FFN_GELU; cur = build_ffn(cur, - model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_up, model.layers[il].ffn_up_b, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, model.layers[il].ffn_down_b, NULL, NULL, - model.layers[il].ffn_gate ? LLM_FFN_GELU : LLM_FFN_GEGLU, LLM_FFN_PAR, il); + type_op, LLM_FFN_PAR, il); cb(cur, "ffn_out", il); } else { cur = build_ffn(cur, diff --git a/src/models/cogvlm.cpp b/src/models/cogvlm.cpp index edf0d1424c..0ceae3aaeb 100644 --- a/src/models/cogvlm.cpp +++ b/src/models/cogvlm.cpp @@ -3,12 +3,14 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; - float kq_scale = 1.0f / sqrtf(float(n_embd_head)); + const float kq_scale = 1.0f / sqrtf(float(n_embd_head)); GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); GGML_ASSERT(n_embd_head == hparams.n_rot); - ggml_tensor *inpL, *cur; + ggml_tensor * inpL; + ggml_tensor * cur; + inpL = build_inp_embd(model.tok_embd); ggml_tensor * inp_pos = build_inp_pos(); @@ -44,7 +46,7 @@ llm_build_cogvlm::llm_build_cogvlm(const llama_model & model, const llm_graph_pa } ggml_tensor * inpSA = inpL; - cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); + cur = build_norm(inpSA, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); // build self attention { diff --git a/src/models/cohere2-iswa.cpp b/src/models/cohere2-iswa.cpp index b18aa8c4e6..9334b5e426 100644 --- a/src/models/cohere2-iswa.cpp +++ b/src/models/cohere2-iswa.cpp @@ -21,6 +21,9 @@ llm_build_cohere2_iswa::llm_build_cohere2_iswa(const llama_model & model, const for (int il = 0; il < n_layer; ++il) { const bool is_swa = hparams.is_swa(il); + // UNUSED: + // const float freq_base_l = model.get_rope_freq_base (cparams, il); + // const float freq_scale_l = model.get_rope_freq_scale(cparams, il); // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, LLM_NORM, il); diff --git a/src/models/deepseek2.cpp b/src/models/deepseek2.cpp index 49382874ba..ca63a62ad1 100644 --- a/src/models/deepseek2.cpp +++ b/src/models/deepseek2.cpp @@ -215,7 +215,7 @@ llm_build_deepseek2::llm_build_deepseek2(const llama_model & model, const llm_gr model.layers[il].ffn_exp_probs_b, n_expert, n_expert_used, LLM_FFN_SILU, hparams.expert_weights_norm, - true, hparams.expert_weights_scale, + hparams.expert_weights_scale, hparams.expert_weights_scale, (llama_expert_gating_func_type) hparams.expert_gating_func, il); cb(moe_out, "ffn_moe_out", il); diff --git a/src/models/gemma-embedding.cpp b/src/models/gemma-embedding.cpp index 90a98f7abf..944c198bf9 100644 --- a/src/models/gemma-embedding.cpp +++ b/src/models/gemma-embedding.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_k; @@ -12,10 +10,8 @@ llm_build_gemma_embedding::llm_build_gemma_embedding(const llama_model & model, inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/src/models/gemma2-iswa.cpp b/src/models/gemma2-iswa.cpp index 9cc59a53ee..7a9198193a 100644 --- a/src/models/gemma2-iswa.cpp +++ b/src/models/gemma2-iswa.cpp @@ -19,6 +19,9 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + // norm cur = build_norm(inpL, model.layers[il].attn_norm, NULL, @@ -43,12 +46,12 @@ llm_build_gemma2_iswa::llm_build_gemma2_iswa(const llama_model & model, const ll Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); cb(Qcur, "Qcur", il); diff --git a/src/models/gemma3.cpp b/src/models/gemma3.cpp index ae60ef4790..dec3fc4b8b 100644 --- a/src/models/gemma3.cpp +++ b/src/models/gemma3.cpp @@ -10,10 +10,9 @@ llm_build_gemma3::llm_build_gemma3(const llama_model & model, const llm_gr inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); diff --git a/src/models/gemma3n-iswa.cpp b/src/models/gemma3n-iswa.cpp index a0bdd6a15a..9c7b3ba0bb 100644 --- a/src/models/gemma3n-iswa.cpp +++ b/src/models/gemma3n-iswa.cpp @@ -1,7 +1,5 @@ #include "models.h" - - llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params), model(model), @@ -15,10 +13,9 @@ llm_build_gemma3n_iswa::llm_build_gemma3n_iswa(const llama_model & model, const inpL = build_inp_embd(model.tok_embd); // important: do not normalize weights for raw embeddings input (i.e. encoded image emdeddings) - if (ubatch.token) { - inpL = ggml_scale(ctx0, inpL, sqrtf(n_embd)); - cb(inpL, "inp_scaled", -1); - } + inpL = ggml_scale(ctx0, inpL, ubatch.token ? sqrtf(n_embd) : 1.0f); + cb(inpL, "inp_scaled", -1); + // inp_pos - contains the positions ggml_tensor * inp_pos = build_inp_pos(); @@ -248,7 +245,7 @@ ggml_tensor * llm_build_gemma3n_iswa::view_2d_slice(ggml_tensor * x, int idx) { // equivalent to get_per_layer_inputs() in python code // output shape: [n_embd_altup, n_layer, n_tokens] ggml_tensor * llm_build_gemma3n_iswa::get_per_layer_inputs() { - auto inp = std::make_unique(); + auto inp = std::make_unique(); ggml_tensor * inp_per_layer; if (ubatch.token) { inp->tokens = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, ubatch.n_tokens); diff --git a/src/models/llama-iswa.cpp b/src/models/llama-iswa.cpp index 03f8061682..61dd2c179f 100644 --- a/src/models/llama-iswa.cpp +++ b/src/models/llama-iswa.cpp @@ -25,8 +25,12 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous const bool use_rope = hparams.n_no_rope_layer_step > 0 && (il + 1) % hparams.n_no_rope_layer_step != 0; @@ -67,13 +71,13 @@ llm_build_llama_iswa::llm_build_llama_iswa(const llama_model & model, const llm_ if (use_rope) { Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, rope_factors, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); } else if (inp_attn_scale) { diff --git a/src/models/maincoder.cpp b/src/models/maincoder.cpp new file mode 100644 index 0000000000..da57308167 --- /dev/null +++ b/src/models/maincoder.cpp @@ -0,0 +1,117 @@ +#include "models.h" + +llm_build_maincoder::llm_build_maincoder(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + // inp_pos - contains the positions + ggml_tensor * inp_pos = build_inp_pos(); + + auto * inp_attn = build_attn_inp_kv(); + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + // norm + cur = build_norm(inpL, + model.layers[il].attn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, nullptr, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + cb(Qcur, "Qcur", il); + cb(Kcur, "Kcur", il); + cb(Vcur, "Vcur", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, 1.0f/sqrtf(float(n_embd_head)), il); + } + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, NULL, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + cur = build_ffn(cur, + model.layers[il].ffn_up, NULL, NULL, + model.layers[il].ffn_gate, NULL, NULL, + model.layers[il].ffn_down, NULL, NULL, + NULL, + LLM_FFN_SILU, LLM_FFN_PAR, il); + cb(cur, "ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_inp); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + cur = inpL; + + cur = build_norm(cur, + model.output_norm, NULL, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); +} diff --git a/src/models/models.h b/src/models/models.h index e2cd4e484f..72b2b760c6 100644 --- a/src/models/models.h +++ b/src/models/models.h @@ -312,6 +312,10 @@ struct llm_build_llama_iswa : public llm_graph_context { llm_build_llama_iswa(const llama_model & model, const llm_graph_params & params); }; +struct llm_build_maincoder : public llm_graph_context { + llm_build_maincoder(const llama_model & model, const llm_graph_params & params); +}; + struct llm_build_mamba : public llm_graph_context_mamba { llm_build_mamba(const llama_model & model, const llm_graph_params & params); }; @@ -332,7 +336,6 @@ struct llm_build_mistral3 : public llm_graph_context { llm_build_mistral3(const llama_model & model, const llm_graph_params & params); }; -template struct llm_build_modern_bert : public llm_graph_context { llm_build_modern_bert(const llama_model & model, const llm_graph_params & params); }; diff --git a/src/models/modern-bert.cpp b/src/models/modern-bert.cpp index c7809bdedf..bb12ed819f 100644 --- a/src/models/modern-bert.cpp +++ b/src/models/modern-bert.cpp @@ -1,7 +1,6 @@ #include "models.h" -template -llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { +llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { const int64_t n_embd_head = hparams.n_embd_head_v; const int64_t n_embd_gqa = hparams.n_embd_v_gqa(); @@ -24,13 +23,8 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, co auto * inp_attn = build_attn_inp_no_cache(); for (int il = 0; il < n_layer; ++il) { - float freq_base_l = 0.0f; - - if constexpr (iswa) { - freq_base_l = model.get_rope_freq_base(cparams, il); - } else { - freq_base_l = freq_base; - } + const float freq_base_l = model.get_rope_freq_base(cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); cur = inpL; @@ -55,13 +49,13 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, co // RoPE Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); @@ -120,7 +114,3 @@ llm_build_modern_bert::llm_build_modern_bert(const llama_model & model, co res->t_embd = cur; ggml_build_forward_expand(gf, cur); } - -// Explicit template instantiations -template struct llm_build_modern_bert; -template struct llm_build_modern_bert; diff --git a/src/models/openai-moe-iswa.cpp b/src/models/openai-moe-iswa.cpp index 96596709ee..dbe3ca1851 100644 --- a/src/models/openai-moe-iswa.cpp +++ b/src/models/openai-moe-iswa.cpp @@ -14,6 +14,9 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); + ggml_tensor * inpSA = inpL; // norm @@ -49,13 +52,13 @@ llm_build_openai_moe_iswa::llm_build_openai_moe_iswa(const llama_model & model, Qcur = ggml_rope_ext( ctx0, Qcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); Kcur = ggml_rope_ext( ctx0, Kcur, inp_pos, nullptr, - n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow ); diff --git a/src/models/smallthinker.cpp b/src/models/smallthinker.cpp index 277eec2955..4c497ca76f 100644 --- a/src/models/smallthinker.cpp +++ b/src/models/smallthinker.cpp @@ -26,10 +26,16 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { - ggml_tensor * inpSA = inpL; - ggml_tensor * probs = nullptr; + const float freq_base_l = model.get_rope_freq_base (cparams, il); + const float freq_scale_l = model.get_rope_freq_scale(cparams, il); - probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] + ggml_tensor * inpSA = inpL; + + // This overlaps with SWA layers in current models, so get_rope_freq_base/scale may be superfluous + const bool use_rope = hparams.n_no_rope_layer_step == n_layer || + il % hparams.n_no_rope_layer_step != 0; + + ggml_tensor * probs = build_lora_mm(model.layers[il].ffn_gate_inp, inpL); // [n_expert, n_tokens] cb(probs, "ffn_moe_logits", il); // norm @@ -52,11 +58,11 @@ llm_build_smallthinker::llm_build_smallthinker(const llama_model & model, Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); - if (hparams.n_no_rope_layer_step == n_layer || il % hparams.n_no_rope_layer_step != 0) { - Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + if (use_rope) { + Qcur = ggml_rope_ext(ctx0, Qcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); - Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + Kcur = ggml_rope_ext(ctx0, Kcur, inp_pos, nullptr, n_rot, rope_type, n_ctx_orig, freq_base_l, freq_scale_l, ext_factor, attn_factor, beta_fast, beta_slow); } cb(Qcur, "Qcur", il); diff --git a/src/unicode.cpp b/src/unicode.cpp index bb44edfadd..b47dcbe619 100644 --- a/src/unicode.cpp +++ b/src/unicode.cpp @@ -964,6 +964,11 @@ std::vector unicode_regex_split(const std::string & text, const std { "\\p{P}", unicode_cpt_flags::PUNCTUATION }, { "\\p{M}", unicode_cpt_flags::ACCENT_MARK }, { "\\p{S}", unicode_cpt_flags::SYMBOL }, + { "\\p{Lu}", unicode_cpt_flags::LETTER }, // Uppercase letter + { "\\p{Ll}", unicode_cpt_flags::LETTER }, // Lowercase letter + { "\\p{Lt}", unicode_cpt_flags::LETTER }, // Titlecase letter + { "\\p{Lm}", unicode_cpt_flags::LETTER }, // Modifier letter + { "\\p{Lo}", unicode_cpt_flags::LETTER }, // Other letter }; static const std::map k_ucat_cpt = { @@ -1074,22 +1079,26 @@ std::vector unicode_regex_split(const std::string & text, const std continue; } - if (regex_expr[i + 0] == '\\' && i + 4 < regex_expr.size() && + // Match \p{...} Unicode properties of varying lengths + if (regex_expr[i + 0] == '\\' && i + 3 < regex_expr.size() && regex_expr[i + 1] == 'p' && - regex_expr[i + 2] == '{' && - regex_expr[i + 4] == '}') { - const std::string pat = regex_expr.substr(i, 5); - if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { - if (!inside) { - regex_expr_collapsed += '['; + regex_expr[i + 2] == '{') { + // Find the closing brace + size_t closing_brace = regex_expr.find('}', i + 3); + if (closing_brace != std::string::npos && closing_brace <= i + 10) { // reasonable limit + const std::string pat = regex_expr.substr(i, closing_brace - i + 1); + if (k_ucat_enum.find(pat) != k_ucat_enum.end()) { + if (!inside) { + regex_expr_collapsed += '['; + } + regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); + regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); + if (!inside) { + regex_expr_collapsed += ']'; + } + i = closing_brace; + continue; } - regex_expr_collapsed += k_ucat_cpt.at(k_ucat_enum.at(pat)); - regex_expr_collapsed += k_ucat_map.at(k_ucat_enum.at(pat)); - if (!inside) { - regex_expr_collapsed += ']'; - } - i += 4; - continue; } } diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c3d9f9c324..6245cd967a 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -219,8 +219,18 @@ endif() llama_build_and_test(test-gguf.cpp) llama_build_and_test(test-backend-ops.cpp) -llama_build_and_test(test-model-load-cancel.cpp LABEL "model") -llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-model-load-cancel.cpp LABEL "model") +llama_build_and_test(test-autorelease.cpp LABEL "model") +llama_build_and_test(test-backend-sampler.cpp LABEL "model") + +llama_test(test-backend-sampler NAME test-backend-sampler-greedy ARGS --test greedy) +llama_test(test-backend-sampler NAME test-backend-sampler-temp ARGS --test temp) +llama_test(test-backend-sampler NAME test-backend-sampler-top_k ARGS --test top_k) +llama_test(test-backend-sampler NAME test-backend-sampler-dist ARGS --test dist) +llama_test(test-backend-sampler NAME test-backend-sampler-dist-and-cpu ARGS --test dist_and_cpu) +llama_test(test-backend-sampler NAME test-backend-sampler-logit-bias ARGS --test logit_bias) +llama_test(test-backend-sampler NAME test-backend-sampler-mul_seq ARGS --test multi_sequence) +llama_test(test-backend-sampler NAME test-backend-sampler-set-sampler ARGS --test set_sampler) # Test for state restore with fragmented KV cache # Requires a model, uses same args pattern as test-thread-safety diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index 1bbb745e78..e995974a2e 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -127,6 +127,15 @@ int main(void) { assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_SPECULATIVE)); assert(params.speculative.n_max == 123); + // multi-value args (CSV) + argv = {"binary_name", "--lora", "file1.gguf,\"file2,2.gguf\",\"file3\"\"3\"\".gguf\",file4\".gguf"}; + assert(true == common_params_parse(argv.size(), list_str_to_char(argv).data(), params, LLAMA_EXAMPLE_COMMON)); + assert(params.lora_adapters.size() == 4); + assert(params.lora_adapters[0].path == "file1.gguf"); + assert(params.lora_adapters[1].path == "file2,2.gguf"); + assert(params.lora_adapters[2].path == "file3\"3\".gguf"); + assert(params.lora_adapters[3].path == "file4\".gguf"); + // skip this part on windows, because setenv is not supported #ifdef _WIN32 printf("test-arg-parser: skip on windows build\n"); diff --git a/tests/test-backend-ops.cpp b/tests/test-backend-ops.cpp index 6dedd8de58..15567abedc 100644 --- a/tests/test-backend-ops.cpp +++ b/tests/test-backend-ops.cpp @@ -3431,6 +3431,65 @@ struct test_rms_norm_mul_add : public test_case { } }; +// GGML_OP_ADD + GGML_OP_RMS_NORM (fused operation) +struct test_add_rms_norm : public test_case { + const ggml_type type; + const std::array ne; + const float eps; + const bool broadcast; + + std::string op_desc(ggml_tensor * t) override { + GGML_UNUSED(t); + return "ADD_RMS_NORM"; + } + + bool run_whole_graph() override { return true; } + + std::string vars() override { + return VARS_TO_STR4(type, ne, eps, broadcast); + } + + test_add_rms_norm(ggml_type type = GGML_TYPE_F32, + std::array ne = {64, 5, 4, 3}, + float eps = 1e-6f, bool broadcast = false) + : type(type), ne(ne), eps(eps), broadcast(broadcast) {} + + ggml_tensor * build_graph(ggml_context * ctx) override { + std::array broadcast_dims = {ne[0]*2, ne[1]*3, ne[2]*3, ne[3]*4}; + + ggml_tensor * a = ggml_new_tensor(ctx, type, 4, broadcast ? broadcast_dims.data() : ne.data()); + ggml_tensor * b = ggml_new_tensor(ctx, type, 4, ne.data()); + + ggml_set_param(a); + ggml_set_name(a, "a"); + ggml_set_param(b); + ggml_set_name(b, "b"); + + // ADD operation followed by RMS_NORM + ggml_tensor * add_result = ggml_add(ctx, a, b); + ggml_set_name(add_result, "add_result"); + + ggml_tensor * out = ggml_rms_norm(ctx, add_result, eps); + ggml_set_name(out, "out"); + + return out; + } + + void initialize_tensors(ggml_context * ctx) override { + for (ggml_tensor * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { + init_tensor_uniform(t, -10.f, 10.f); + } + } + + float grad_eps() override { + return 1.0f; + } + + bool grad_precise() override { + return true; + } +}; + // GGML_OP_SSM_CONV struct test_ssm_conv : public test_case { const ggml_type type; @@ -7393,11 +7452,14 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); test_cases.emplace_back(new test_norm_mul_add(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); + test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, false)); + test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {64, 5, 4, 3}, eps, true)); } for (uint32_t n : {1, 511, 1025, 8192, 33*512}) { for (bool multi_add : {false, true}) { test_cases.emplace_back(new test_rms_norm_mul_add(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false, multi_add)); } + test_cases.emplace_back(new test_add_rms_norm(GGML_TYPE_F32, {n, 1, 1, 1}, 1e-6f, false)); } for (auto multi_add : {false, true}) { @@ -7563,6 +7625,10 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F32, GGML_TYPE_F32, 64, 77, 77, {12,1}, {1,1})); test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 576, 512, 576, {1,1}, {1,1})); + test_cases.emplace_back(new test_mul_mat(GGML_TYPE_Q4_0, GGML_TYPE_F32, 1, 2048, 8192, {1, 1}, {1, 1})); + for (ggml_type type_a : all_types) { + test_cases.emplace_back(new test_mul_mat(type_a, GGML_TYPE_F32, 1, 64, 256, {1, 1}, {1, 1})); + } #if 0 // test the mat-mat path for Metal @@ -7775,8 +7841,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {32, 2, 32, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); - test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F32, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200001, 2, 3, 1}, true, true, GGML_TYPE_F16, {1, 1}, 0.1f, 8.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 1, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {200000, 4, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {643251, 3, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); for (float max_bias : {0.0f, 8.0f}) { for (float scale : {1.0f, 0.1f}) { @@ -7880,6 +7949,11 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {2, 8, 8192, 1}, order)); // bailingmoe2 (group selection) } + for (int n = 1; n < 5; ++n) { + for (int k = 1; k <= n; ++k) { + test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {n, 2, 1, 3}, k, true)); + } + } for (int i = 0; i < 20; ++i) { for (int k : {1, 2, 3, 7, 15, 100, 500, 1023, 9999}) { if (k <= 1<> make_test_cases_eval() { test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 2048, 5, 4, 3 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 201*1204, 1, 1, 1 })); test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 312*1205, 1, 1, 1 })); + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, { 20481, 4, 1, 1 })); test_cases.emplace_back(new test_xielu()); @@ -8109,6 +8184,7 @@ static std::vector> make_test_cases_eval() { test_cases.emplace_back(new test_topk_moe({71, 22, 1, 1}, 8, with_norm, bias_probs, gate, scale_w)); test_cases.emplace_back(new test_topk_moe({128, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); test_cases.emplace_back(new test_topk_moe({129, 1, 1, 1}, 128, with_norm, bias_probs, gate, scale_w)); + test_cases.emplace_back(new test_topk_moe({160, 4, 1, 1}, 160, with_norm, bias_probs, gate, scale_w)); } } } @@ -8294,6 +8370,12 @@ static std::vector> make_test_cases_perf() { } } + for (int col : {8192, 16384, 32768, 65536, 131072, 262144, 524288}) { + for (int rows : {1, 4, 16}){ + test_cases.emplace_back(new test_soft_max(GGML_TYPE_F32, {col, rows, 1, 1}, false, false, GGML_TYPE_F32, {1, 1}, 1.0f, 0.0f)); + } + } + test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, false)); test_cases.emplace_back(new test_conv_2d_dw({512, 512, 256, 1}, {3, 3, 1, 256}, 1, 1, 1, true)); @@ -8337,7 +8419,9 @@ static std::vector> make_test_cases_perf() { test_cases.emplace_back(new test_sum(GGML_TYPE_F32, it)); } - test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1})); + test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1})); test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1)); for (auto k : {1, 10, 40, 400}) { @@ -8348,13 +8432,18 @@ static std::vector> make_test_cases_perf() { } } + for (auto nrows : {1, 4, 8, 16}) { + for (auto cols : {128, 1024, 4096, 8192, 16384, 32768, 65536, 131072, 200000, 2000000}) { + test_cases.emplace_back(new test_cumsum(GGML_TYPE_F32, {cols, nrows, 1, 1})); + } + } + // Examples from granite-4.0-h-1b/ggml-model-Q8_0.gguf test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {515, 3328, 1, 1}, {4, 3328, 1, 1})); // prefill test_cases.emplace_back(new test_ssm_conv(GGML_TYPE_F32, {4, 3328, 1, 1}, {4, 3328, 1, 1})); // generate test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 512, 1)); // prefill test_cases.emplace_back(new test_ssm_scan(GGML_TYPE_F32, 128, 64, 48, 1, 1, 1)); // generate - return test_cases; } diff --git a/tests/test-backend-sampler.cpp b/tests/test-backend-sampler.cpp new file mode 100644 index 0000000000..24ece9d4b1 --- /dev/null +++ b/tests/test-backend-sampler.cpp @@ -0,0 +1,1237 @@ +#include "ggml.h" +#include "llama.h" +#include "llama-cpp.h" +#include "get-model.h" +#include "common.h" + +#ifdef NDEBUG +#undef NDEBUG +#endif + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +struct backend_cli_args { + const char * model = nullptr; + const char * test = nullptr; + const char * device = "cpu"; +}; + +struct test_model_context { + llama_model_ptr model; + llama_context_ptr ctx; + int n_vocab = 0; + + std::unordered_map seq_positions; + std::unordered_map last_batch_info; + + bool load_model(const backend_cli_args & args) { + if (model) { + return true; + } + + llama_backend_init(); + + auto mparams = llama_model_default_params(); + + ggml_backend_dev_t devs[2]; + if (std::string_view(args.device) == "gpu") { + ggml_backend_dev_t gpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); + if (gpu == nullptr) { + fprintf(stderr, "Error: GPU requested but not available\n"); + return false; + } + devs[0] = gpu; + devs[1] = nullptr; // null terminator + mparams.devices = devs; + mparams.n_gpu_layers = 999; + } else if (std::string_view(args.device) == "cpu") { + ggml_backend_dev_t cpu = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); + devs[0] = cpu; + devs[1] = nullptr; // null terminator + mparams.devices = devs; + } + + fprintf(stderr, "Using device: %s\n", ggml_backend_dev_name(devs[0])); + + model.reset(llama_model_load_from_file(args.model, mparams)); + + if (!model) { + fprintf(stderr, "Warning: failed to load model '%s', skipping test\n", args.model); + return false; + } + n_vocab = llama_vocab_n_tokens(get_vocab()); + fprintf(stderr, "Vocabulary size: %d\n", n_vocab); + + return true; + } + + bool setup(const backend_cli_args & args, std::vector & configs, int32_t n_seq_max = -1) { + if (!model) { + load_model(args); + } + + if (ctx) { + return true; + } + + llama_context_params cparams = llama_context_default_params(); + cparams.n_ctx = 512; + cparams.n_batch = 512; + cparams.samplers = configs.data(); + cparams.n_samplers = configs.size(); + + // If n_seq_max is not specified, calculate it from configs + if (n_seq_max < 0) { + int32_t max_seq_id = 0; + for (const auto & config : configs) { + max_seq_id = std::max(config.seq_id, max_seq_id); + } + cparams.n_seq_max = max_seq_id + 1; + } else { + cparams.n_seq_max = n_seq_max; + } + + ctx.reset(llama_init_from_model(model.get(), cparams)); + if (!ctx) { + fprintf(stderr, "Warning: failed to create context, skipping test\n"); + return false; + } + llama_set_warmup(ctx.get(), false); + + return true; + } + + bool decode(const std::map & prompts) { + if (!ctx) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + last_batch_info.clear(); + llama_batch batch = llama_batch_init(512, 0, prompts.size()); + + auto vocab = get_vocab(); + for (const auto & [seq_id, prompt] : prompts) { + std::vector tokens; + tokens.push_back(llama_vocab_bos(vocab)); + + std::vector prompt_tokens(32); + int n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + if (n_tokens < 0) { + fprintf(stderr, "Warning: tokenization failed for seq_id %d\n", seq_id); + llama_batch_free(batch); + return false; + } + + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + if (seq_positions.find(seq_id) == seq_positions.end()) { + seq_positions[seq_id] = 0; + } + + int32_t start_pos = seq_positions[seq_id]; + for (size_t i = 0; i < tokens.size(); i++) { + common_batch_add(batch, tokens[i], start_pos + i, { seq_id }, i == tokens.size() - 1); + } + + seq_positions[seq_id] = start_pos + tokens.size(); + } + + + printf("Batch contents:\n"); + printf("n_tokens: %d\n", batch.n_tokens); + for (int i = 0; i < batch.n_tokens; i++) { + printf("token[%d]: tok=%-5d, pos=%d, n_seq_id=%d, seq_ids=[", i, batch.token[i], batch.pos[i], batch.n_seq_id[i]); + + for (int j = 0; j < batch.n_seq_id[i]; j++) { + printf("%d%s", batch.seq_id[i][j], j < batch.n_seq_id[i]-1 ? ", " : ""); + } + printf("], logits=%d\n", batch.logits[i]); + } + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed\n"); + llama_batch_free(batch); + return false; + } + + // Build mapping from seq id to batch token idx + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id seq_id = batch.seq_id[i][0]; + last_batch_info[seq_id] = i; + } + } + + llama_batch_free(batch); + return true; + } + + int32_t idx_for_seq(llama_seq_id seq_id) { + auto it = last_batch_info.find(seq_id); + if (it == last_batch_info.end()) { + fprintf(stderr, "Error: no batch index found for seq_id %d\n", seq_id); + return -1; + } + return it->second; + } + + void update_batch_info(const llama_batch & batch) { + last_batch_info.clear(); + for (int i = 0; i < batch.n_tokens; i++) { + if (batch.logits[i]) { + llama_seq_id cur_seq = batch.seq_id[i][0]; + last_batch_info[cur_seq] = i; + } + } + } + + bool decode_token(llama_token token, llama_seq_id seq_id = 0) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(1, 0, 1); + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for token %d in seq %d\n", token, seq_id); + llama_batch_free(batch); + return false; + } + + update_batch_info(batch); + + seq_positions[seq_id]++; + llama_batch_free(batch); + return true; + } + + bool decode_tokens(const std::map & seq_tokens) { + if (ctx == nullptr) { + fprintf(stderr, "Error: context not initialized, call setup() first\n"); + return false; + } + + llama_batch batch = llama_batch_init(seq_tokens.size(), 0, seq_tokens.size()); + + for (const auto & [seq_id, token] : seq_tokens) { + int32_t pos = seq_positions[seq_id]; + common_batch_add(batch, token, pos, { seq_id }, true); + } + + if (llama_decode(ctx.get(), batch) != 0) { + fprintf(stderr, "Warning: llama_decode failed for batch tokens\n"); + llama_batch_free(batch); + return false; + } + + for (const auto & [seq_id, _] : seq_tokens) { + seq_positions[seq_id]++; + } + + update_batch_info(batch); + + llama_batch_free(batch); + return true; + } + + std::string token_to_piece(llama_token token, bool special) { + std::string piece; + piece.resize(piece.capacity()); // using string internal cache, 15 bytes + '\n' + const int n_chars = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special); + if (n_chars < 0) { + piece.resize(-n_chars); + int check = llama_token_to_piece(get_vocab(), token, &piece[0], piece.size(), 0, special); + GGML_ASSERT(check == -n_chars); + } + else { + piece.resize(n_chars); + } + + return piece; + } + + void reset() { + ctx.reset(); + seq_positions.clear(); + last_batch_info.clear(); + } + + const llama_vocab * get_vocab() const { + return model ? llama_model_get_vocab(model.get()) : nullptr; + } + +}; + +static void test_backend_greedy_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + + struct llama_sampler_chain_params backend_sampler_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_sampler_params)); + + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_greedy()); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1); + printf("greedy sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), loop_idx); + printf("Generation step %d: token id:%d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } +} + +static void test_backend_top_k_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t k = 8; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_k(k)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + for (size_t i = 0; i < n_logits; ++i) { + printf("top_k logit[%zu] = %.6f\n", i, logits[i]); + } + + llama_token * candidates = llama_get_sampled_candidates_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_candidates = llama_get_sampled_candidates_count_ith(test_ctx.ctx.get(), batch_idx); + for (size_t i = 0; i < n_candidates; ++i) { + printf("top_k candidate[%zu] = %d : %s\n", i, candidates[i], + test_ctx.token_to_piece(candidates[i], false).c_str()); + } + + // Sample using CPU sampler for verification that it is possible to do hybrid + // sampling, first top_k on the backend and then dist on the CPU. + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + GGML_ASSERT(chain->iface->backend_apply != nullptr); + + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend top-k hybrid sampling test PASSED\n"); +} + +static void test_backend_temp_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + + { + const float temp_0 = 0.8f; + struct llama_sampler_chain_params backend_chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain_0(llama_sampler_chain_init(backend_chain_params_0)); + llama_sampler_chain_add(backend_sampler_chain_0.get(), llama_sampler_init_temp(temp_0)); + + const float temp_1 = 0.1f; + struct llama_sampler_chain_params backend_chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain_1(llama_sampler_chain_init(backend_chain_params_1)); + llama_sampler_chain_add(backend_sampler_chain_1.get(), llama_sampler_init_temp(temp_1)); + + std::vector backend_sampler_configs = { + { 0, backend_sampler_chain_0.get() }, + { 1, backend_sampler_chain_1.get() } + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{0, "Some where over the"}, {1, "Once upon a"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfify sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + + // Sample from sequence 0 using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Sequence 0 sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + + // Verfify sequence 1 + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + + // Sample from sequence 1 using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Sequence 1 sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + } + + // lambda to testing non-positive temperature values. + auto test_argmax_temp = [&](float temp) { + printf("\nTesting temperature = %.1f\n", temp); + + test_ctx.reset(); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp(temp)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == 1); + }; + + test_argmax_temp(0.0f); + test_argmax_temp(-1.0f); + + printf("backend temp sampling test PASSED\n"); + +} + +static void test_backend_temp_ext_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + { + int seq_id = 0; + const float temp = 0.8f; + const float delta = 0.5f; + const float exponent = 1.5f; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once upon a"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + int n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == test_ctx.n_vocab); + } + } + + test_ctx.reset(); + + // lambda to testing non-positive temp/delta/exponent values. + auto test_argmax_temp = [&](float temp, float delta, float exponent) { + printf("\nTesting temperature = %.1f, delta = %1.f, exponent = %1.f\n", temp, delta, exponent); + + test_ctx.reset(); + + int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_temp_ext(temp, delta, exponent)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Once"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + if (temp <= 0.0f && delta >= 0.0f) { + GGML_ASSERT(n_logits == 1); + } else { + GGML_ASSERT(n_logits == (uint32_t) test_ctx.n_vocab); + } + }; + + test_argmax_temp(0.0f, 0.3f, 1.0f); // Greedy (temp=0) + test_argmax_temp(-1.0f, 0.3f, 2.0f); // Greedy (temp<0) + test_argmax_temp(0.8f, 0.0f, 2.0f); // Temperature scaling + + printf("backend temp_ext sampling test PASSED\n"); + +} + +static void test_backend_min_p_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + const float p = 0.1; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_min_p(p, 0)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + // Print the logits that are above the min-p threshold + std::vector filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + //printf("min_p logit[%zu] = %.6f\n", i, logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("min-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); + printf("min-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + if (!test_ctx.decode_token(token, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + printf("min-p sampling test PASSED\n"); +} + +static void test_backend_top_p_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + const float p = 0.9; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_top_p(p, 0)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + return; + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + uint32_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + + // Print the logits that are above the min-p threshold + std::vector filtered_logits; + for (size_t i = 0; i < n_logits; ++i) { + if (logits[i] > -1e9f) { + filtered_logits.push_back(logits[i]); + } + } + GGML_ASSERT(filtered_logits.size() < (size_t) test_ctx.n_vocab); + GGML_ASSERT(filtered_logits.size() > 0); + + // Sample using CPU sampler for verification to inspect they are reasonable + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(88)); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("top-p cpu sampled token id:%d, string: '%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + // Decode and sampler 10 more tokens + for (int i = 0; i < 10; i++) { + int32_t loop_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), loop_idx); + printf("top-p gen step %d: token id :%5.d, string: %s\n", i, token, test_ctx.token_to_piece(token, false).c_str()); + test_ctx.decode_token(token, 0); + } + + printf("top-p sampling test PASSED\n"); +} + +static void test_backend_multi_sequence_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_greedy()); + + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_temp(0.8f)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_greedy()); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + { 1, sampler_chain_1.get() } + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + std::map prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfiy sequence 0 + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 + { + int32_t batch_idx= test_ctx.idx_for_seq(1); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 1 sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Generate tokens for each sequence + printf("\nMulti-sequence generation:\n"); + for (int step = 0; step < 4; step++) { + std::map tokens; + + for (llama_seq_id seq_id : {0, 1}) { + int32_t idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf(" Seq %d, step %d: token id=%d, string='%s'\n", seq_id, step, token, token_str.c_str()); + tokens[seq_id] = token; + } + + // Decode all tokens in a single batch + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + printf("backend multi-sequence sampling test PASSED\n"); +} + +static void test_backend_dist_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 189; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr); + + token = llama_get_sampled_token_ith(test_ctx.ctx.get(), -1); + printf("dist sampled id:%d, string:'%s'\n", token, test_ctx.token_to_piece(token, false).c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + + printf("backend dist sampling test PASSED\n"); +} + +static void test_backend_dist_sampling_and_cpu(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Some"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + llama_token cpu_token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + printf("dist & cpu sampled id:%d, string:'%s'\n", cpu_token, test_ctx.token_to_piece(cpu_token, false).c_str()); + GGML_ASSERT(backend_token == cpu_token); + + printf("backend dist & cpu sampling test PASSED\n"); +} + +static void test_backend_logit_bias_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + // Calling load_model to ensure vocab is loaded and can be accessed + if (!test_ctx.load_model(args)) { + return; + } + + const int seq_id = 0; + + // Create the logit biases vector. + std::vector logit_bias; + + // Get the token for the piece "World". + const std::string piece = "World"; + std::vector tokens(16); + llama_tokenize(test_ctx.get_vocab(), piece.c_str(), piece.size(), tokens.data(), tokens.size(), false, false); + llama_token bias_token = tokens[0]; + logit_bias.push_back({ bias_token, +100.0f }); + printf("biasing token piece '%s' -> token id %d\n", piece.c_str(), bias_token); + + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_logit_bias( + llama_vocab_n_tokens(test_ctx.get_vocab()), + logit_bias.size(), + logit_bias.data())); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(88)); + + std::vector backend_sampler_configs = { + { seq_id, backend_sampler_chain.get() }, + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id)); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("logit bias sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + GGML_ASSERT(backend_token == bias_token); + + printf("backend logit bias sampling test PASSED\n"); +} + +// This test verifies that it is possible to have two different backend sampler, +// one that uses the backend dist sampler, and another that uses CPU dist sampler. +static void test_backend_mixed_sampling(const backend_cli_args & args) { + test_model_context test_ctx; + + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88)); + + int k = 40; + struct llama_sampler_chain_params chain_params_1 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_1(llama_sampler_chain_init(chain_params_1)); + llama_sampler_chain_add(sampler_chain_1.get(), llama_sampler_init_top_k(k)); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + { 1, sampler_chain_1.get() } + }; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + std::map prompts = { + {0, "Hello"}, + {1, "Some"} + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verfiy sequence 0 that used the dist backend sampler. + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + //GGML_ASSERT(llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx) == nullptr); + //GGML_ASSERT(llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx) == 0); + } + + // Verfiy sequence 1 that used the top-k backend sampler. + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + float * logits = llama_get_sampled_logits_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(logits != nullptr); + size_t n_logits = llama_get_sampled_logits_count_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(n_logits == (size_t) k); + GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx) == LLAMA_TOKEN_NULL); + } + + printf("backend mixed sampling test PASSED\n"); +} + +static void test_backend_set_sampler(const backend_cli_args & args) { + test_model_context test_ctx; + + const int32_t seed = 88; + const int seq_id = 0; + struct llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + if (!test_ctx.decode({{seq_id, "Hello"}})) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(seq_id); + + // Sample using backend sampler configured above + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string backend_token_str = test_ctx.token_to_piece(backend_token, false); + printf("dist sampled token = %d, string='%s'\n", backend_token, backend_token_str.c_str()); + + // Now clear the backend sampler for this sequence. + llama_set_sampler(test_ctx.ctx.get(), seq_id, nullptr); + printf("Cleared backend sampler for seq_id %d\n", seq_id); + + // Sample using CPU sampler + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_dist(18)); + + std::map tokens = { { seq_id, backend_token}, }; + if (!test_ctx.decode_tokens(tokens)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Should not have any sampled token or probs after clearing the backend sampler. + const int32_t idx = test_ctx.idx_for_seq(seq_id); + GGML_ASSERT(llama_get_sampled_token_ith(test_ctx.ctx.get(), idx) == LLAMA_TOKEN_NULL); + GGML_ASSERT(llama_get_sampled_probs_ith(test_ctx.ctx.get(), idx) == nullptr); + + // Sample the token using the CPU sampler chain. + llama_token token2 = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), seq_id); + const std::string token2_str = test_ctx.token_to_piece(token2, false); + printf("CPU sampled token after clearing backend sampler: id=%d, string='%s'\n", token2, token2_str.c_str()); + std::map tokens2 = { { seq_id, token2}, }; + + // Set a new backend sampler for the sequence. + struct llama_sampler_chain_params new_backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr new_backend_sampler_chain(llama_sampler_chain_init(new_backend_chain_params)); + llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_top_k(20)); + llama_sampler_chain_add(new_backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + llama_set_sampler(test_ctx.ctx.get(), seq_id, new_backend_sampler_chain.get()); + + if (!test_ctx.decode_tokens(tokens2)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + llama_token new_backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), test_ctx.idx_for_seq(seq_id)); + const std::string new_backend_token_str = test_ctx.token_to_piece(new_backend_token, false); + printf("dist sampled token = %d, string='%s'\n", new_backend_token, new_backend_token_str.c_str()); + + printf("backend set sampler test PASSED\n"); +} + +static void test_backend_cpu_mixed_batch(const backend_cli_args & args) { + test_model_context test_ctx; + + // Sequence 0 uses backend sampling + struct llama_sampler_chain_params chain_params_0 = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain_0(llama_sampler_chain_init(chain_params_0)); + llama_sampler_chain_add(sampler_chain_0.get(), llama_sampler_init_dist(88)); + + std::vector backend_sampler_configs = { + { 0, sampler_chain_0.get() }, + }; + + // We need 2 sequences: seq 0 with backend sampling, seq 1 with CPU sampling + if (!test_ctx.setup(args, backend_sampler_configs, 2)) { + return; + } + + std::map prompts = { + {0, "Hello"}, // Will use backend sampling + {1, "Some"} // Will use CPU sampling + }; + + if (!test_ctx.decode(prompts)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + // Verify sequence 0 (backend sampled) + { + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 0 (backend) sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Verify sequence 1 (CPU sampled) + { + int32_t batch_idx = test_ctx.idx_for_seq(1); + + llama_token backend_token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + GGML_ASSERT(backend_token == LLAMA_TOKEN_NULL); + + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy()); + + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("Seq 1 (CPU) sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + // Clear/remove the backend sampler, and sample again + { + // clear the backend sampler for seq 0 so that there are no backend + // samplers. + llama_set_sampler(test_ctx.ctx.get(), 0, nullptr); + + // Create a CPU sampler and verify we can sampler from it. + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(chain.get(), llama_sampler_init_greedy()); + + int32_t batch_idx = test_ctx.idx_for_seq(1); + llama_token token = llama_sampler_sample(chain.get(), test_ctx.ctx.get(), batch_idx); + if (!test_ctx.decode_token(token, 1)) { + GGML_ASSERT(false && "Failed to decode token"); + } + } + + // Set a backend sampler so that we can verify that it can be reset + { + struct llama_sampler_chain_params chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr sampler_chain(llama_sampler_chain_init(chain_params)); + llama_sampler_chain_add(sampler_chain.get(), llama_sampler_init_dist(88)); + + llama_set_sampler(test_ctx.ctx.get(), 0, sampler_chain.get()); + + if (!test_ctx.decode_token(3834, 0)) { + GGML_ASSERT(false && "Failed to decode token"); + } + + int32_t batch_idx = test_ctx.idx_for_seq(0); + llama_token token = llama_get_sampled_token_ith(test_ctx.ctx.get(), batch_idx); + const std::string token_str = test_ctx.token_to_piece(token, false); + printf("re-added backend sampled token id=%d, string='%s'\n", token, token_str.c_str()); + GGML_ASSERT(token >= 0 && token < test_ctx.n_vocab); + } + + printf("backend-cpu mixed batch test PASSED\n"); +} + +static void test_backend_max_outputs(const backend_cli_args & args) { + test_model_context test_ctx; + + const int seq_id = 0; + const int32_t seed = 88; + llama_sampler_chain_params backend_chain_params = llama_sampler_chain_default_params(); + llama_sampler_ptr backend_sampler_chain(llama_sampler_chain_init(backend_chain_params)); + llama_sampler_chain_add(backend_sampler_chain.get(), llama_sampler_init_dist(seed)); + std::vector backend_sampler_configs = {{ seq_id, backend_sampler_chain.get() }}; + + if (!test_ctx.setup(args, backend_sampler_configs)) { + return; + } + + llama_batch batch = llama_batch_init(512, 0, 1); + std::string prompt = "Hello"; + + std::vector tokens; + tokens.push_back(llama_vocab_bos(test_ctx.get_vocab())); + + std::vector prompt_tokens(32); + int n_tokens = llama_tokenize(test_ctx.get_vocab(), prompt.c_str(), prompt.length(), + prompt_tokens.data(), prompt_tokens.size(), + false, false); + for (int i = 0; i < n_tokens; i++) { + tokens.push_back(prompt_tokens[i]); + } + + for (size_t i = 0; i < tokens.size(); i++) { + // set all tokens as output to trigger error + common_batch_add(batch, tokens[i], i, { seq_id }, true); + } + + printf(">>> test_max_outputs expected error start:\n"); + const int ret = llama_decode(test_ctx.ctx.get(), batch); + GGML_ASSERT(ret != 0 && "llama_decode should not succeed multiple outputs per sequence"); + printf("<<< test_max_outputs expected error end.\n"); + llama_batch_free(batch); + + printf("backend max outputs test PASSED\n"); +} + +struct backend_test_case { + const char * name; + void (*fn)(const backend_cli_args &); + bool enabled_by_default; +}; + +static const backend_test_case BACKEND_TESTS[] = { + { "greedy", test_backend_greedy_sampling, true }, + { "logit_bias", test_backend_logit_bias_sampling, true }, + { "temp", test_backend_temp_sampling, true }, + { "temp_ext", test_backend_temp_ext_sampling, true }, + { "top_k", test_backend_top_k_sampling, true }, + { "multi_sequence", test_backend_multi_sequence_sampling, true }, + { "dist", test_backend_dist_sampling, true }, + { "dist_and_cpu", test_backend_dist_sampling_and_cpu, true }, + { "set_sampler", test_backend_set_sampler, true }, + { "max_outputs", test_backend_max_outputs, true }, + { "mixed", test_backend_mixed_sampling, true }, + { "min_p", test_backend_min_p_sampling, true }, + { "cpu_mixed", test_backend_cpu_mixed_batch, true }, + { "top_p", test_backend_top_p_sampling, true }, +}; + +static backend_cli_args parse_backend_cli(int argc, char ** argv) { + backend_cli_args out; + + for (int i = 1; i < argc; ++i) { + const char * arg = argv[i]; + + if (std::strcmp(arg, "--test") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--test expects a value\n"); + exit(EXIT_FAILURE); + } + out.test = argv[++i]; + continue; + } + if (std::strncmp(arg, "--test=", 7) == 0) { + out.test = arg + 7; + continue; + } + if (std::strcmp(arg, "--model") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--model expects a value\n"); + exit(EXIT_FAILURE); + } + out.model = argv[++i]; + continue; + } + if (std::strncmp(arg, "--model=", 8) == 0) { + out.model = arg + 8; + continue; + } + if (std::strcmp(arg, "--device") == 0) { + if (i + 1 >= argc) { + fprintf(stderr, "--device expects a value (cpu or gpu)\n"); + exit(EXIT_FAILURE); + } + out.device = argv[++i]; + continue; + } + if (std::strncmp(arg, "--device=", 9) == 0) { + out.device = arg + 9; + continue; + } + if (!out.model) { + out.model = arg; + continue; + } + + fprintf(stderr, "Unexpected argument: %s\n", arg); + exit(EXIT_FAILURE); + } + + if (std::strcmp(out.device, "cpu") != 0 && std::strcmp(out.device, "gpu") != 0) { + fprintf(stderr, "Invalid device '%s'. Must be 'cpu' or 'gpu'\n", out.device); + exit(EXIT_FAILURE); + } + + return out; +} + +static std::vector collect_tests_to_run(const char * requested) { + std::vector selected; + + if (requested != nullptr) { + for (const auto & test : BACKEND_TESTS) { + if (std::strcmp(test.name, requested) == 0) { + selected.push_back(&test); + break; + } + } + if (selected.empty()) { + fprintf(stderr, "Unknown test '%s'. Available tests:\n", requested); + for (const auto & test : BACKEND_TESTS) { + fprintf(stderr, " %s\n", test.name); + } + exit(EXIT_FAILURE); + } + } else { + for (const auto & test : BACKEND_TESTS) { + if (test.enabled_by_default) { + selected.push_back(&test); + } + } + } + + if (selected.empty()) { + fprintf(stderr, "No backend sampling tests selected. Use --test= to pick one.\n"); + } + + return selected; +} + +static void run_tests(const std::vector & tests, const backend_cli_args & args) { + for (const auto * test : tests) { + fprintf(stderr, "\n=== %s ===\n", test->name); + test->fn(args); + } +} + + +int main(int argc, char ** argv) { + backend_cli_args args = parse_backend_cli(argc, argv); + + if (args.model == nullptr) { + args.model = get_model_or_exit(1, argv); + } + + std::ifstream file(args.model); + if (!file.is_open()) { + fprintf(stderr, "no model '%s' found\n", args.model); + return EXIT_FAILURE; + } + + fprintf(stderr, "using '%s'\n", args.model); + + ggml_time_init(); + + const std::vector tests = collect_tests_to_run(args.test); + if (!tests.empty()) { + run_tests(tests, args); + } + + return 0; +} diff --git a/tests/test-regex-partial.cpp b/tests/test-regex-partial.cpp index ffad189786..70af6d75a1 100644 --- a/tests/test-regex-partial.cpp +++ b/tests/test-regex-partial.cpp @@ -232,52 +232,52 @@ static void test_regex_to_reversed_partial_regex() { printf("[%s]\n", __func__); assert_equals( - "((?:(?:c)?b)?a)[\\s\\S]*", + "^((?:(?:c)?b)?a)", regex_to_reversed_partial_regex("abc")); assert_equals( - "(a+)[\\s\\S]*", + "^(a+)", regex_to_reversed_partial_regex("a+")); assert_equals( - "(a*)[\\s\\S]*", + "^(a*)", regex_to_reversed_partial_regex("a*")); assert_equals( - "(a?)[\\s\\S]*", + "^(a?)", regex_to_reversed_partial_regex("a?")); assert_equals( - "([a-z])[\\s\\S]*", + "^([a-z])", regex_to_reversed_partial_regex("[a-z]")); assert_equals( - "((?:\\w+)?[a-z])[\\s\\S]*", + "^((?:\\w+)?[a-z])", regex_to_reversed_partial_regex("[a-z]\\w+")); assert_equals( - "((?:a|b))[\\s\\S]*", + "^((?:a|b))", regex_to_reversed_partial_regex("(?:a|b)")); assert_equals( - "((?:(?:(?:d)?c)?b)?a)[\\s\\S]*", + "^((?:(?:(?:d)?c)?b)?a)", regex_to_reversed_partial_regex("abcd")); assert_equals( - "((?:b)?a*)[\\s\\S]*", // TODO: ((?:b)?a*+).* ?? + "^((?:b)?a*)", // TODO: ((?:b)?a*+).* ?? regex_to_reversed_partial_regex("a*b")); assert_equals( - "((?:(?:b)?a)?.*)[\\s\\S]*", + "^((?:(?:b)?a)?.*)", regex_to_reversed_partial_regex(".*?ab")); assert_equals( - "((?:(?:b)?.*)?a)[\\s\\S]*", + "^((?:(?:b)?.*)?a)", regex_to_reversed_partial_regex("a.*?b")); assert_equals( - "((?:(?:d)?(?:(?:c)?b))?a)[\\s\\S]*", + "^((?:(?:d)?(?:(?:c)?b))?a)", regex_to_reversed_partial_regex("a(bc)d")); assert_equals( - "((?:(?:(?:c)?b|(?:e)?d))?a)[\\s\\S]*", + "^((?:(?:(?:c)?b|(?:e)?d))?a)", regex_to_reversed_partial_regex("a(bc|de)")); assert_equals( - "((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)[\\s\\S]*", + "^((?:(?:(?:(?:(?:c)?b?)?b?)?b)?b)?a)", regex_to_reversed_partial_regex("ab{2,4}c")); } diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index 8df3f41003..48959fefb5 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -25,7 +25,6 @@ else() if (LLAMA_BUILD_SERVER) add_subdirectory(server) endif() - add_subdirectory(run) add_subdirectory(tokenize) add_subdirectory(tts) add_subdirectory(mtmd) diff --git a/tools/mtmd/CMakeLists.txt b/tools/mtmd/CMakeLists.txt index 317d5f19fd..4b9022cb58 100644 --- a/tools/mtmd/CMakeLists.txt +++ b/tools/mtmd/CMakeLists.txt @@ -27,6 +27,7 @@ add_library(mtmd models/qwen3vl.cpp models/siglip.cpp models/whisper-enc.cpp + models/youtuvl.cpp ) set_target_properties(mtmd PROPERTIES diff --git a/tools/mtmd/clip-impl.h b/tools/mtmd/clip-impl.h index 1ed0741883..df7e479765 100644 --- a/tools/mtmd/clip-impl.h +++ b/tools/mtmd/clip-impl.h @@ -45,13 +45,14 @@ #define KEY_SPATIAL_MERGE_SIZE "clip.vision.spatial_merge_size" #define KEY_IS_DEEPSTACK_LAYERS "clip.vision.is_deepstack_layers" -#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" -#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" -#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" -#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" -#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" -#define KEY_MINICPMV_VERSION "clip.minicpmv_version" -#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" +#define KEY_MM_PATCH_MERGE_TYPE "clip.vision.mm_patch_merge_type" +#define KEY_IMAGE_GRID_PINPOINTS "clip.vision.image_grid_pinpoints" +#define KEY_IMAGE_CROP_RESOLUTION "clip.vision.image_crop_resolution" +#define KEY_WIN_ATTN_PATTERN "clip.vision.n_wa_pattern" +#define KEY_WIN_ATTN_LAYER_INDEXES "clip.vision.wa_layer_indexes" +#define KEY_ATTN_WINDOW_SIZE "clip.vision.window_size" +#define KEY_MINICPMV_VERSION "clip.minicpmv_version" +#define KEY_MINICPMV_QUERY_NUM "clip.minicpmv_query_num" // audio-specific #define KEY_AUDIO_PROJ_TYPE "clip.audio.projector_type" // for models with mixed modalities @@ -188,6 +189,7 @@ enum projector_type { PROJECTOR_TYPE_JANUS_PRO, PROJECTOR_TYPE_LFM2A, PROJECTOR_TYPE_GLM4V, + PROJECTOR_TYPE_YOUTUVL, PROJECTOR_TYPE_UNKNOWN, }; @@ -218,6 +220,7 @@ static std::map PROJECTOR_TYPE_NAMES = { { PROJECTOR_TYPE_JANUS_PRO, "janus_pro"}, { PROJECTOR_TYPE_LFM2A, "lfm2a"}, { PROJECTOR_TYPE_GLM4V, "glm4v"}, + { PROJECTOR_TYPE_YOUTUVL, "youtuvl"}, }; static projector_type clip_projector_type_from_string(const std::string & str) { diff --git a/tools/mtmd/clip-model.h b/tools/mtmd/clip-model.h index 1e5aa87b98..702e10151a 100644 --- a/tools/mtmd/clip-model.h +++ b/tools/mtmd/clip-model.h @@ -61,6 +61,7 @@ struct clip_hparams { std::unordered_set vision_feature_layer; int32_t attn_window_size = 0; int32_t n_wa_pattern = 0; + std::unordered_set wa_layer_indexes; // explicit layer indexes that use full attention (for irregular patterns like YoutuVL) // audio int32_t n_mel_bins = 0; // whisper preprocessor diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp index fb08dd258c..9c9abd8d2e 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -846,6 +846,10 @@ static ggml_cgraph * clip_image_build_graph(clip_ctx * ctx, const clip_image_f32 { builder = std::make_unique(ctx, img); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + builder = std::make_unique(ctx, img); + } break; default: GGML_ABORT("missing cgraph builder"); } @@ -1159,6 +1163,20 @@ struct clip_model_loader { LOG_WRN("%s: more info: https://github.com/ggml-org/llama.cpp/issues/16842\n\n", __func__); } } break; + case PROJECTOR_TYPE_YOUTUVL: + { + hparams.n_merge = 2; + get_u32(KEY_SPATIAL_MERGE_SIZE, hparams.n_merge, false); + get_u32(KEY_ATTN_WINDOW_SIZE, hparams.attn_window_size, true); + std::vector wa_layer_indexes_vec; + get_arr_int(KEY_WIN_ATTN_LAYER_INDEXES, wa_layer_indexes_vec, true); + for (auto & layer : wa_layer_indexes_vec) { + hparams.wa_layer_indexes.insert(layer); + } + // support max_height * max_width = 8000 * 8000. 8000/16/2 = 250 image tokens + hparams.set_limit_image_tokens(1, 62500); + hparams.set_warmup_n_tokens(16*16); // avoid OOM on warmup + } break; case PROJECTOR_TYPE_GLM4V: { hparams.rope_theta = 10000.0f; @@ -1227,7 +1245,14 @@ struct clip_model_loader { LOG_INF("%s: has_llava_proj: %d\n", __func__, hparams.has_llava_projector); LOG_INF("%s: minicpmv_version: %d\n", __func__, hparams.minicpmv_version); LOG_INF("%s: n_merge: %d\n", __func__, hparams.n_merge); - LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + LOG_INF("%s: n_wa_pattern: %d\n", __func__, hparams.n_wa_pattern); + if (!hparams.wa_layer_indexes.empty()) { + LOG_INF("%s: wa_layer_indexes: ", __func__); + for (auto & layer : hparams.wa_layer_indexes) { + LOG_INF("%d ", layer); + } + LOG_INF("\n"); + } if (hparams.image_min_pixels > 0) { LOG_INF("%s: image_min_pixels: %d%s\n", __func__, hparams.image_min_pixels, hparams.custom_image_min_tokens > 0 ? " (custom value)" : ""); } @@ -1495,6 +1520,14 @@ struct clip_model_loader { model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); // merger.ln_q (RMS norm) + model.mm_0_w = get_tensor(string_format(TN_LLAVA_PROJ, 0, "weight")); // merger.mlp.0 + model.mm_0_b = get_tensor(string_format(TN_LLAVA_PROJ, 0, "bias")); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); // merger.mlp.2 + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_GLM4V: { model.projection = get_tensor(TN_MM_PROJECTOR); @@ -1519,6 +1552,14 @@ struct clip_model_loader { model.projection = get_tensor(TN_MM_PROJECTOR); } break; case PROJECTOR_TYPE_LFM2: + { + model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM, false); + model.mm_input_norm_b = get_tensor(TN_MM_INP_NORM_B, false); + model.mm_1_w = get_tensor(string_format(TN_LLAVA_PROJ, 1, "weight")); + model.mm_1_b = get_tensor(string_format(TN_LLAVA_PROJ, 1, "bias")); + model.mm_2_w = get_tensor(string_format(TN_LLAVA_PROJ, 2, "weight")); + model.mm_2_b = get_tensor(string_format(TN_LLAVA_PROJ, 2, "bias")); + } break; case PROJECTOR_TYPE_KIMIVL: { model.mm_input_norm_w = get_tensor(TN_MM_INP_NORM); @@ -2697,6 +2738,57 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); } break; + case PROJECTOR_TYPE_YOUTUVL: + { + const int patch_size = params.patch_size; // typically 16 + const int merge_size = params.n_merge; // typically 2 + const int align_size = patch_size * merge_size; // 32 + + const int max_num_patches = params.image_max_pixels > 0 ? + params.image_max_pixels / (patch_size * patch_size) : 256; + + // Linear search for optimal scale to fit within max_num_patches + float scale = 1.0f; + int target_height = original_size.height; + int target_width = original_size.width; + + auto get_scaled_image_size = [align_size](float scale, int size) -> int { + float scaled_size = size * scale; + // Round up to nearest multiple of align_size + int aligned = static_cast(std::ceil(scaled_size / align_size)) * align_size; + // Ensure at least one patch + return std::max(align_size, aligned); + }; + + // Linear search with 0.02 step size + while (scale > 0.0f) { + target_height = get_scaled_image_size(scale, original_size.height); + target_width = get_scaled_image_size(scale, original_size.width); + + int num_patches_h = target_height / patch_size; + int num_patches_w = target_width / patch_size; + int num_patches = num_patches_h * num_patches_w; + + if (num_patches > max_num_patches) { + scale -= 0.02f; + } else { + break; + } + } + + clip_image_size new_size = {target_width, target_height}; + + // Resize the image + clip_image_u8 resized; + img_tool::resize(*img, resized, new_size, img_tool::RESIZE_ALGO_BILINEAR, false); + + // Normalize to float32 + clip_image_f32_ptr img_f32(clip_image_f32_init()); + normalize_image_u8_to_f32(resized, *img_f32, params.image_mean, params.image_std); + + // Add to results + res_imgs->entries.push_back(std::move(img_f32)); + } break; case PROJECTOR_TYPE_IDEFICS3: { @@ -2929,6 +3021,7 @@ int clip_n_output_tokens_x(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->nx / params.patch_size) / 2; default: break; @@ -2944,6 +3037,7 @@ int clip_n_output_tokens_y(const struct clip_ctx * ctx, struct clip_image_f32 * case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: return (img->ny / params.patch_size) / 2; default: break; @@ -3004,6 +3098,7 @@ int clip_n_output_tokens(const struct clip_ctx * ctx, struct clip_image_f32 * im case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_QWEN3VL: case PROJECTOR_TYPE_GLM4V: + case PROJECTOR_TYPE_YOUTUVL: { // dynamic size (2 conv, so double patch size) int x_patch = img->nx / (params.patch_size * 2); @@ -3131,7 +3226,6 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima const int pos_w = image_size_width / patch_size; const int pos_h = image_size_height / patch_size; - const bool use_window_attn = hparams.n_wa_pattern > 0; // for qwen2.5vl auto get_inp_tensor = [&gf](const char * name) { ggml_tensor * inp = ggml_graph_get_tensor(gf, name); @@ -3280,9 +3374,11 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima set_input_i32("positions", positions); } break; case PROJECTOR_TYPE_QWEN25VL: + case PROJECTOR_TYPE_YOUTUVL: { // pw * ph = number of tokens output by ViT after apply patch merger // ipw * ipw = number of vision token been processed inside ViT + const bool use_window_attn = ctx->model.proj_type == PROJECTOR_TYPE_QWEN25VL ? hparams.n_wa_pattern > 0 : !hparams.wa_layer_indexes.empty(); const int merge_ratio = 2; const int pw = image_size_width / patch_size / merge_ratio; const int ph = image_size_height / patch_size / merge_ratio; @@ -3293,7 +3389,7 @@ bool clip_image_batch_encode(clip_ctx * ctx, const int n_threads, const clip_ima std::vector inv_idx(ph * pw); if (use_window_attn) { - const int attn_window_size = 112; + const int attn_window_size = hparams.attn_window_size > 0 ? hparams.attn_window_size : 112; const int grid_window = attn_window_size / patch_size / merge_ratio; int dst = 0; // [num_vision_tokens, num_vision_tokens] attention mask tensor @@ -3531,6 +3627,7 @@ int clip_n_mmproj_embd(const struct clip_ctx * ctx) { case PROJECTOR_TYPE_QWEN2VL: case PROJECTOR_TYPE_QWEN25VL: case PROJECTOR_TYPE_JANUS_PRO: + case PROJECTOR_TYPE_YOUTUVL: return ctx->model.mm_1_b->ne[0]; case PROJECTOR_TYPE_QWEN3VL: // main path + deepstack paths diff --git a/tools/mtmd/models/models.h b/tools/mtmd/models/models.h index e08c33f353..74e94f60ec 100644 --- a/tools/mtmd/models/models.h +++ b/tools/mtmd/models/models.h @@ -27,6 +27,11 @@ struct clip_graph_qwen3vl : clip_graph { ggml_cgraph * build() override; }; +struct clip_graph_youtuvl : clip_graph { + clip_graph_youtuvl(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} + ggml_cgraph * build() override; +}; + struct clip_graph_minicpmv : clip_graph { clip_graph_minicpmv(clip_ctx * ctx, const clip_image_f32 & img) : clip_graph(ctx, img) {} ggml_cgraph * build() override; diff --git a/tools/mtmd/models/siglip.cpp b/tools/mtmd/models/siglip.cpp index ef094cfd0e..b866a11c5a 100644 --- a/tools/mtmd/models/siglip.cpp +++ b/tools/mtmd/models/siglip.cpp @@ -50,10 +50,15 @@ ggml_cgraph * clip_graph_siglip::build() { const int scale_factor = model.hparams.n_merge; cur = build_patch_merge_permute(cur, scale_factor); - // projection - cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm - cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); - cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + // projection, in LFM2-VL input norm is optional + if (model.mm_input_norm_w) { + cur = ggml_norm(ctx0, cur, 1e-5); // default nn.LayerNorm + cur = ggml_mul(ctx0, cur, model.mm_input_norm_w); + } + + if (model.mm_input_norm_b) { + cur = ggml_add(ctx0, cur, model.mm_input_norm_b); + } cur = build_ffn(cur, model.mm_1_w, model.mm_1_b, diff --git a/tools/mtmd/models/youtuvl.cpp b/tools/mtmd/models/youtuvl.cpp new file mode 100644 index 0000000000..ffbf2be554 --- /dev/null +++ b/tools/mtmd/models/youtuvl.cpp @@ -0,0 +1,179 @@ +#include "models.h" + +ggml_cgraph * clip_graph_youtuvl::build() { + GGML_ASSERT(model.class_embedding == nullptr); + const int batch_size = 1; + const bool use_window_attn = !hparams.wa_layer_indexes.empty(); + const int n_pos = n_patches; + const int num_position_ids = n_pos * 4; + const int m = 2; + const int Wp = n_patches_x; + const int Hp = n_patches_y; + const int Hm = Hp / m; + const int Wm = Wp / m; + norm_type norm_t = NORM_TYPE_NORMAL; + + int mrope_sections[4] = {d_head/4, d_head/4, d_head/4, d_head/4}; + + ggml_tensor * inp = build_inp_raw(); + + // change conv3d to linear + // reshape and permute to get patches, permute from (patch_size, m, Wm, patch_size, m, Hm, C) to (C, patch_size, patch_size, m, m, Wm, Hm) + { + inp = ggml_reshape_4d( + ctx0, inp, + Wm * m * patch_size, m * patch_size, Hm, 3); + inp = ggml_permute(ctx0, inp, 1, 2, 3, 0); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, Wm, m * patch_size, Hm); + + inp = ggml_permute(ctx0, inp, 0, 2, 1, 3); + inp = ggml_cont_4d( + ctx0, inp, + m * patch_size * 3, patch_size, m, Hm * Wm); + + inp = ggml_permute(ctx0, inp, 1, 0, 2, 3); + inp = ggml_cont_4d( + ctx0, inp, + patch_size, 3, patch_size, Hm * Wm * m * m); + + inp = ggml_permute(ctx0, inp, 2, 0, 1, 3); + inp = ggml_cont_3d( + ctx0, inp, + 3*patch_size* patch_size, Hm * Wm * m * m, 1); + } + inp = ggml_mul_mat(ctx0, model.patch_embeddings_0, inp); + + if (model.patch_bias) { + inp = ggml_add(ctx0, inp, model.patch_bias); + } + + inp = ggml_reshape_2d(ctx0, inp, n_embd, n_patches); + + ggml_tensor * inpL = inp; + ggml_tensor * window_mask = nullptr; + ggml_tensor * window_idx = nullptr; + ggml_tensor * inv_window_idx = nullptr; + + ggml_tensor * positions = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, num_position_ids); + ggml_set_name(positions, "positions"); + ggml_set_input(positions); + + // pre-layernorm + if (model.pre_ln_w) { + inpL = build_norm(inpL, model.pre_ln_w, model.pre_ln_b, norm_t, eps, -1); + } + if (use_window_attn) { + inv_window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / 4); + ggml_set_name(inv_window_idx, "inv_window_idx"); + ggml_set_input(inv_window_idx); + // mask for window attention + window_mask = ggml_new_tensor_2d(ctx0, GGML_TYPE_F32, n_pos, n_pos); + ggml_set_name(window_mask, "window_mask"); + ggml_set_input(window_mask); + + // if flash attn is used, we need to pad the mask and cast to f16 + if (flash_attn_type == CLIP_FLASH_ATTN_TYPE_ENABLED) { + window_mask = ggml_cast(ctx0, window_mask, GGML_TYPE_F16); + } + + // inpL shape: [n_embd, n_patches_x * n_patches_y, batch_size] + GGML_ASSERT(batch_size == 1); + inpL = ggml_reshape_2d(ctx0, inpL, n_embd * 4, n_patches_x * n_patches_y * batch_size / 4); + inpL = ggml_get_rows(ctx0, inpL, inv_window_idx); + inpL = ggml_reshape_3d(ctx0, inpL, n_embd, n_patches_x * n_patches_y, batch_size); + } + + // loop over layers + for (int il = 0; il < n_layer; il++) { + const auto & layer = model.layers[il]; + const bool full_attn = use_window_attn ? hparams.wa_layer_indexes.count(il) > 0 : true; + + ggml_tensor * cur = inpL; // inpL = residual, cur = hidden_states + + // layernorm1 + cur = build_norm(cur, layer.ln_1_w, layer.ln_1_b, norm_t, eps, il); + // self-attention + { + ggml_tensor * Qcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.q_w, cur), layer.q_b); + ggml_tensor * Kcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.k_w, cur), layer.k_b); + ggml_tensor * Vcur = ggml_add(ctx0, + ggml_mul_mat(ctx0, layer.v_w, cur), layer.v_b); + + Qcur = ggml_reshape_3d(ctx0, Qcur, d_head, n_head, n_patches); + Kcur = ggml_reshape_3d(ctx0, Kcur, d_head, n_head, n_patches); + Vcur = ggml_reshape_3d(ctx0, Vcur, d_head, n_head, n_patches); + + Qcur = ggml_rope_multi( + ctx0, Qcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + Kcur = ggml_rope_multi( + ctx0, Kcur, positions, nullptr, + d_head/2, mrope_sections, GGML_ROPE_TYPE_VISION, 32768, 10000, 1, 0, 1, 32, 1); + + ggml_tensor * attn_mask = full_attn ? nullptr : window_mask; + + cur = build_attn(layer.o_w, layer.o_b, + Qcur, Kcur, Vcur, attn_mask, kq_scale, il); + } + // re-add the layer input, e.g., residual + cur = ggml_add(ctx0, cur, inpL); + + inpL = cur; // inpL = residual, cur = hidden_states + + // layernorm2 + cur = build_norm(cur, layer.ln_2_w, layer.ln_2_b, norm_t, eps, il); + + // ffn + cur = build_ffn(cur, + layer.ff_up_w, layer.ff_up_b, + nullptr, nullptr, + layer.ff_down_w, layer.ff_down_b, + hparams.ffn_op, il); + + // residual 2 + cur = ggml_add(ctx0, inpL, cur); + + inpL = cur; + } + + ggml_tensor * embeddings = inpL; + if (use_window_attn) { + const int spatial_merge_unit = 4; + window_idx = ggml_new_tensor_1d(ctx0, GGML_TYPE_I32, n_pos / spatial_merge_unit); + ggml_set_name(window_idx, "window_idx"); + ggml_set_input(window_idx); + GGML_ASSERT(batch_size == 1); + embeddings = ggml_reshape_2d(ctx0, embeddings, n_embd * spatial_merge_unit, n_patches / spatial_merge_unit); + embeddings = ggml_get_rows(ctx0, embeddings, window_idx); + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd, n_patches, batch_size); + cb(embeddings, "window_order_restored", -1); + } + + // post-layernorm (part of Siglip2VisionTransformer, applied after encoder) + if (model.post_ln_w) { + embeddings = build_norm(embeddings, model.post_ln_w, model.post_ln_b, norm_t, eps, n_layer); + } + + // Now apply merger (VLPatchMerger): + // 1. Apply RMS norm (ln_q in VLPatchMerger) + embeddings = build_norm(embeddings, model.mm_input_norm_w, nullptr, NORM_TYPE_RMS, 1e-6, -1); + cb(embeddings, "merger_normed", -1); + + // 2. First reshape for spatial merge (merge 2x2 patches) + embeddings = ggml_reshape_3d(ctx0, embeddings, n_embd * 4, n_pos / 4, batch_size); + cb(embeddings, "merger_reshaped", -1); + + embeddings = build_ffn(embeddings, + model.mm_0_w, model.mm_0_b, + nullptr, nullptr, + model.mm_1_w, model.mm_1_b, + FFN_GELU, + -1); + ggml_build_forward_expand(gf, embeddings); + + return gf; +} diff --git a/tools/mtmd/mtmd-audio.cpp b/tools/mtmd/mtmd-audio.cpp index e99101184b..e8eef035ff 100644 --- a/tools/mtmd/mtmd-audio.cpp +++ b/tools/mtmd/mtmd-audio.cpp @@ -9,207 +9,250 @@ #include #include -// most of the code here is copied from whisper.cpp +// some of the code here is copied from whisper.cpp constexpr bool DEBUG = false; -struct mtmd_audio_mel_filters { - int32_t n_mel; - int32_t n_fft; +void mtmd_audio_cache::fill_sin_cos_table(int n) { + sin_vals.resize(n); + cos_vals.resize(n); + for (int i = 0; i < n; i++) { + double theta = (2 * M_PI * i) / n; + sin_vals[i] = sinf(theta); + cos_vals[i] = cosf(theta); + } +} - std::vector data; -}; +void mtmd_audio_cache::fill_hann_window(int length, bool periodic) { + hann_window.resize(length); + int offset = -1; + if (periodic) { + offset = 0; + } + for (int i = 0; i < length; i++) { + hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); + } +} -// note: this global cache is shared among all preprocessors -// if we want to use multiple preprocessors at the same time, -// we will need to enclose it in the preprocessor class in the future -static struct mtmd_audio_global_cache { - // precomputed sin/cos table for FFT - std::vector sin_vals; - std::vector cos_vals; - - // hann window - std::vector hann_window; - - // mel filter bank - mtmd_audio_mel_filters filters; - - void fill_sin_cos_table(int n) { - sin_vals.resize(n); - cos_vals.resize(n); - for (int i = 0; i < n; i++) { - double theta = (2 * M_PI * i) / n; - sin_vals[i] = sinf(theta); - cos_vals[i] = cosf(theta); - } +void mtmd_audio_cache::fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, + float fmin, + float fmax, + bool slaney_area_norm, + float scale) { + GGML_ASSERT(n_mel > 0 && n_fft > 1); + if (fmax <= 0.0f) { + fmax = 0.5f * sample_rate; } - void fill_hann_window(int length, bool periodic) { - hann_window.resize(length); - int offset = -1; - if (periodic) { - offset = 0; - } - for (int i = 0; i < length; i++) { - hann_window[i] = 0.5 * (1.0 - cosf((2.0 * M_PI * i) / (length + offset))); - } + // Slaney scale (matches librosa default) + const double min_log_hz = 1000.0; + const double lin_slope = 3 / 200.; + const double min_log_mel = min_log_hz * lin_slope; + const double log_step = log(6.4) / 27.0; + auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { + return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; + }; + auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { + return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); + }; + + // infer N_fft from n_fft_bins + const double bin_hz_step = double(sample_rate) / double(n_fft); + + // mel grid: n_mel + 2 edges + const double m_lo = hz_to_mel(fmin); + const double m_hi = hz_to_mel(fmax); + std::vector mel_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); } - // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. - // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. - void fill_mel_filterbank_matrix( - int n_mel, - int n_fft, - int sample_rate, // e.g. 16000 - float fmin = 0.0f, // e.g. 0.0 - float fmax = -1.0f, // e.g. sr/2; pass -1 for auto - bool slaney_area_norm = true, - float scale = 1.0f // optional extra scaling; use 1.0f/1000.0f to mimic your code - ) { - GGML_ASSERT(n_mel > 0 && n_fft > 1); - if (fmax <= 0.0f) { - fmax = 0.5f * sample_rate; - } + // convert to Hz + std::vector hz_pts(n_mel + 2); + for (int i = 0; i < n_mel + 2; ++i) { + hz_pts[i] = mel_to_hz(mel_pts[i]); + } - // Slaney scale (matches librosa default) - const double min_log_hz = 1000.0; - const double lin_slope = 3 / 200.; - const double min_log_mel = min_log_hz * lin_slope; - const double log_step = log(6.4) / 27.0; - auto hz_to_mel = [min_log_hz, lin_slope, log_step, min_log_mel](const double f_hz) -> double { - return (f_hz < min_log_hz) ? f_hz * lin_slope : min_log_mel + log(f_hz / min_log_hz) / log_step; - }; - auto mel_to_hz = [min_log_hz, lin_slope, log_step, min_log_mel](const double m) -> double { - return (m < min_log_mel) ? m / lin_slope : min_log_hz * exp((m - min_log_mel) * log_step); - }; + const int n_fft_bins = n_fft / 2 + 1; - // infer N_fft from n_fft_bins - const double bin_hz_step = double(sample_rate) / double(n_fft); + // filterbank + std::vector out(n_mel * n_fft_bins, 0); + for (int m = 0; m < n_mel; ++m) { + const double f_left = hz_pts[m]; + const double f_center = hz_pts[m + 1]; + const double f_right = hz_pts[m + 2]; - // mel grid: n_mel + 2 edges - const double m_lo = hz_to_mel(fmin); - const double m_hi = hz_to_mel(fmax); - std::vector mel_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - mel_pts[i] = m_lo + (m_hi - m_lo) * (double(i) / (n_mel + 1)); - } + const double denom_l = std::max(1e-30, f_center - f_left); + const double denom_r = std::max(1e-30, f_right - f_center); + const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - // convert to Hz - std::vector hz_pts(n_mel + 2); - for (int i = 0; i < n_mel + 2; ++i) { - hz_pts[i] = mel_to_hz(mel_pts[i]); - } - - const int n_fft_bins = n_fft / 2 + 1; - - // filterbank - std::vector out(n_mel * n_fft_bins, 0); - for (int m = 0; m < n_mel; ++m) { - const double f_left = hz_pts[m]; - const double f_center = hz_pts[m + 1]; - const double f_right = hz_pts[m + 2]; - - const double denom_l = std::max(1e-30, f_center - f_left); - const double denom_r = std::max(1e-30, f_right - f_center); - const double enorm = slaney_area_norm ? (2.0 / std::max(1e-30, f_right - f_left)) : 1.0; - - for (int k = 0; k < n_fft_bins; ++k) { - const double f = k * bin_hz_step; - double w = 0.0; - if (f >= f_left && f <= f_center) { - w = (f - f_left) / denom_l; - } else if (f > f_center && f <= f_right) { - w = (f_right - f) / denom_r; - } - out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); + for (int k = 0; k < n_fft_bins; ++k) { + const double f = k * bin_hz_step; + double w = 0.0; + if (f >= f_left && f <= f_center) { + w = (f - f_left) / denom_l; + } else if (f > f_center && f <= f_right) { + w = (f_right - f) / denom_r; } + out[size_t(m) * size_t(n_fft_bins) + size_t(k)] = float(w * enorm * scale); } + } - filters.n_mel = n_mel; - filters.n_fft = n_fft; - filters.data = std::move(out); + filters.n_mel = n_mel; + filters.n_fft = n_fft; + filters.data = std::move(out); - if (DEBUG) { // debug - for (size_t i = 0; i < filters.data.size(); ++i) { - if (filters.data[i] != 0.0f) { - printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); - } + if (DEBUG) { // debug + for (size_t i = 0; i < filters.data.size(); ++i) { + if (filters.data[i] != 0.0f) { + printf("filters[%zu] = %f\n", i, filters.data[i] * 1000.0f); } } } -} g_cache; +} -// naive Discrete Fourier Transform -// input is real-valued -// output is complex-valued -static void dft(const float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); - const int sin_cos_step = n_sin_cos_vals / N; +// Unified DFT implementation for both forward and inverse transforms +// Template parameters: +// Inverse: false = DFT with exp(-2πi·k·n/N), no scaling +// true = IDFT with exp(+2πi·k·n/N), scales by 1/N +// RealInput: true = input is real-valued (stride 1), avoids imaginary computations +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void dft_impl(const mtmd_audio_cache & cache, const float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + const float scale = Inverse ? (1.0f / N) : 1.0f; for (int k = 0; k < N; k++) { float re = 0; float im = 0; for (int n = 0; n < N; n++) { - int idx = (k * n * sin_cos_step) % (n_sin_cos_vals); // t = 2*M_PI*k*n/N - re += in[n] * g_cache.cos_vals[idx]; // cos(t) - im -= in[n] * g_cache.sin_vals[idx]; // sin(t) + int idx = (k * n * sin_cos_step) % n_sin_cos_vals; + float cos_val = cache.cos_vals[idx]; + float sin_val = cache.sin_vals[idx]; + + if constexpr (RealInput) { + // Real input: in_im = 0, simplifies to: + // re += in_re * cos_val + // im += sign * in_re * sin_val + float in_re = in[n]; + re += in_re * cos_val; + im += sign * in_re * sin_val; + } else { + float in_re = in[n * 2 + 0]; + float in_im = in[n * 2 + 1]; + // (a + bi) * (cos + sign*i*sin) = (a*cos - sign*b*sin) + (sign*a*sin + b*cos)i + re += in_re * cos_val - sign * in_im * sin_val; + im += sign * in_re * sin_val + in_im * cos_val; + } } - out[k*2 + 0] = re; - out[k*2 + 1] = im; + out[k * 2 + 0] = re * scale; + out[k * 2 + 1] = im * scale; } } -// Cooley-Tukey FFT -// poor man's implementation - use something better -// input is real-valued -// output is complex-valued -static void fft(float * in, int N, float * out) { - const int n_sin_cos_vals = g_cache.sin_vals.size(); +// Cooley-Tukey FFT/IFFT unified implementation +// Template parameters: +// Inverse: false = FFT with exp(-2πi·k/N), no scaling +// true = IFFT with exp(+2πi·k/N), scales by 0.5 at each level +// RealInput: true = input is real-valued (stride 1) +// false = input is complex-valued (interleaved real/imag, stride 2) +template +static void fft_impl(const mtmd_audio_cache & cache, float * in, int N, float * out) { + const int n_sin_cos_vals = cache.sin_vals.size(); + if (N == 1) { out[0] = in[0]; - out[1] = 0; + if constexpr (RealInput) { + out[1] = 0.0f; + } else { + out[1] = in[1]; + } return; } const int half_N = N / 2; - if (N - half_N*2 == 1) { - dft(in, N, out); + if (N - half_N * 2 == 1) { + // Odd N: fall back to DFT + dft_impl(cache, in, N, out); return; } - float* even = in + N; - for (int i = 0; i < half_N; ++i) { - even[i]= in[2*i]; - } - float* even_fft = out + 2 * N; - fft(even, half_N, even_fft); + // Split into even and odd + if constexpr (RealInput) { + // Real input: stride is 1, copy only real values + float * even = in + N; + for (int i = 0; i < half_N; ++i) { + even[i] = in[2 * i]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); - float* odd = even; - for (int i = 0; i < half_N; ++i) { - odd[i] = in[2*i + 1]; + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i] = in[2 * i + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); + } else { + // Complex input: stride is 2, copy complex pairs + float * even = in + N * 2; + for (int i = 0; i < half_N; ++i) { + even[i * 2 + 0] = in[2 * i * 2 + 0]; + even[i * 2 + 1] = in[2 * i * 2 + 1]; + } + float * even_fft = out + 2 * N; + fft_impl(cache, even, half_N, even_fft); + + float * odd = even; + for (int i = 0; i < half_N; ++i) { + odd[i * 2 + 0] = in[(2 * i + 1) * 2 + 0]; + odd[i * 2 + 1] = in[(2 * i + 1) * 2 + 1]; + } + float * odd_fft = even_fft + N; + fft_impl(cache, odd, half_N, odd_fft); } - float* odd_fft = even_fft + N; - fft(odd, half_N, odd_fft); + + float * even_fft = out + 2 * N; + float * odd_fft = even_fft + N; const int sin_cos_step = n_sin_cos_vals / N; + + constexpr float sign = Inverse ? 1.0f : -1.0f; + constexpr float scale = Inverse ? 0.5f : 1.0f; + for (int k = 0; k < half_N; k++) { - int idx = k * sin_cos_step; // t = 2*M_PI*k/N - float re = g_cache.cos_vals[idx]; // cos(t) - float im = -g_cache.sin_vals[idx]; // sin(t) + int idx = k * sin_cos_step; // t = 2*M_PI*k/N + float re = cache.cos_vals[idx]; + float im = sign * cache.sin_vals[idx]; - float re_odd = odd_fft[2*k + 0]; - float im_odd = odd_fft[2*k + 1]; + float re_odd = odd_fft[2 * k + 0]; + float im_odd = odd_fft[2 * k + 1]; - out[2*k + 0] = even_fft[2*k + 0] + re*re_odd - im*im_odd; - out[2*k + 1] = even_fft[2*k + 1] + re*im_odd + im*re_odd; + out[2 * k + 0] = scale * (even_fft[2 * k + 0] + re * re_odd - im * im_odd); + out[2 * k + 1] = scale * (even_fft[2 * k + 1] + re * im_odd + im * re_odd); - out[2*(k + half_N) + 0] = even_fft[2*k + 0] - re*re_odd + im*im_odd; - out[2*(k + half_N) + 1] = even_fft[2*k + 1] - re*im_odd - im*re_odd; + out[2 * (k + half_N) + 0] = scale * (even_fft[2 * k + 0] - re * re_odd + im * im_odd); + out[2 * (k + half_N) + 1] = scale * (even_fft[2 * k + 1] - re * im_odd - im * re_odd); } } +// Forward FFT for real input (used by mel spectrogram) +static void fft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + +// Inverse FFT for complex input +static void ifft(const mtmd_audio_cache & cache, float * in, int N, float * out) { + fft_impl(cache, in, N, out); +} + struct filter_params { int32_t n_mel; int32_t n_fft_bins; @@ -222,20 +265,27 @@ struct filter_params { bool norm_per_feature = false; }; -static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const std::vector & samples, - int n_samples, int frame_size, int frame_step, int n_threads, - const filter_params & params, mtmd_audio_mel & out) { +static void log_mel_spectrogram_worker_thread(int ith, + const float * hann, + const std::vector & samples, + int n_samples, + int frame_size, + int frame_step, + int n_threads, + const filter_params & params, + const mtmd_audio_cache & cache, + mtmd_audio_mel & out) { std::vector fft_in(frame_size * 2, 0.0); std::vector fft_out(frame_size * 2 * 2 * 2); int n_fft_bins = params.n_fft_bins; int i = ith; - const auto & filters = g_cache.filters; + const auto & filters = cache.filters; // make sure n_fft == 1 + (WHISPER_N_FFT / 2), bin_0 to bin_nyquist GGML_ASSERT(n_fft_bins == 1 + (frame_size / 2)); - GGML_ASSERT(g_cache.sin_vals.size() == g_cache.cos_vals.size()); + GGML_ASSERT(cache.sin_vals.size() == cache.cos_vals.size()); // calculate FFT only when fft_in are not all zero for (; i < std::min(n_samples / frame_step + 1, out.n_len); i += n_threads) { const int offset = i * frame_step; @@ -251,7 +301,7 @@ static void log_mel_spectrogram_worker_thread(int ith, const float * hann, const } // FFT - fft(fft_in.data(), frame_size, fft_out.data()); + fft(cache, fft_in.data(), frame_size, fft_out.data()); // Calculate modulus^2 of complex numbers // Use pow(fft_out[2 * j + 0], 2) + pow(fft_out[2 * j + 1], 2) causes inference quality problem? Interesting. @@ -298,6 +348,7 @@ static bool log_mel_spectrogram( const int n_samples_in, const int n_threads, const filter_params & params, + const mtmd_audio_cache & cache, mtmd_audio_mel & out) { //const int64_t t_start_us = ggml_time_us(); @@ -305,9 +356,9 @@ static bool log_mel_spectrogram( int n_samples = n_samples_in; // Hann window - const float * hann = g_cache.hann_window.data(); - const int frame_size = (params.n_fft_bins - 1) * 2; - const int frame_step = params.hop_length; + const float * hann = cache.hann_window.data(); + const int frame_size = (params.n_fft_bins - 1) * 2; + const int frame_step = params.hop_length; // Padding std::vector samples_padded; @@ -335,9 +386,9 @@ static bool log_mel_spectrogram( // preemphasis if (params.preemph) { - const int pad_amount = frame_size / 2; + const int pad_amount = frame_size / 2; const float preemph = 0.97f; - float prev = samples_padded[pad_amount]; + float prev = samples_padded[pad_amount]; for (int i = pad_amount + 1; i + pad_amount < n_samples; ++i) { float cur = samples_padded[i]; samples_padded[i] = cur - preemph * prev; @@ -372,14 +423,14 @@ static bool log_mel_spectrogram( { std::vector workers(n_threads - 1); for (int iw = 0; iw < n_threads - 1; ++iw) { - workers[iw] = std::thread( - log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), - n_samples, frame_size, frame_step, n_threads, - std::cref(params), std::ref(out)); + workers[iw] = + std::thread(log_mel_spectrogram_worker_thread, iw + 1, hann, std::cref(samples_padded), n_samples, + frame_size, frame_step, n_threads, std::cref(params), std::cref(cache), std::ref(out)); } // main thread - log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, out); + log_mel_spectrogram_worker_thread(0, hann, samples_padded, n_samples, frame_size, frame_step, n_threads, params, + cache, out); for (int iw = 0; iw < n_threads - 1; ++iw) { workers[iw].join(); } @@ -404,7 +455,7 @@ static bool log_mel_spectrogram( for (int j = 0; j < effective_n_len; ++j) { auto &value = out.data[i * out.n_len + j]; - value = (value - mean) / mstd; + value = (value - mean) / mstd; } // pad the rest with zeros @@ -450,18 +501,14 @@ static bool log_mel_spectrogram( // void mtmd_audio_preprocessor_whisper::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_whisper::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_whisper::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { if (n_samples == 0) { // empty audio return false; @@ -471,7 +518,7 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // if input is too short, pad with zeros // this is to avoid potential issues with stage1/2 padding in log_mel_spectrogram // TODO: maybe handle this better - size_t min_samples = (size_t)hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin + size_t min_samples = (size_t) hparams.audio_sample_rate * (hparams.audio_chunk_len + 1); // +1 second margin if (n_samples < min_samples) { smpl.resize(min_samples, 0.0f); std::memcpy(smpl.data(), samples, n_samples * sizeof(float)); @@ -486,22 +533,19 @@ bool mtmd_audio_preprocessor_whisper::preprocess( params.hop_length = hparams.audio_hop_len; params.sample_rate = hparams.audio_sample_rate; params.center_padding = false; - params.preemph = 0.0f; // disabled + params.preemph = 0.0f; // disabled params.use_natural_log = false; params.norm_per_feature = false; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -512,21 +556,21 @@ bool mtmd_audio_preprocessor_whisper::preprocess( printf("output: n_mel = %d, n_len = %d\n", out_full.n_mel, out_full.n_len); } const size_t frames_per_chunk = 3000; - GGML_ASSERT((size_t)out_full.n_len > frames_per_chunk); - for (size_t off = 0; off < (size_t)out_full.n_len; off += frames_per_chunk) { - int n_len = std::min(frames_per_chunk, (size_t)out_full.n_len - off); - if ((size_t)n_len < frames_per_chunk) { - break; // last uncomplete chunk will always be a padded chunk, safe to ignore + GGML_ASSERT((size_t) out_full.n_len > frames_per_chunk); + for (size_t off = 0; off < (size_t) out_full.n_len; off += frames_per_chunk) { + int n_len = std::min(frames_per_chunk, (size_t) out_full.n_len - off); + if ((size_t) n_len < frames_per_chunk) { + break; // last uncomplete chunk will always be a padded chunk, safe to ignore } mtmd_audio_mel out_chunk; out_chunk.n_len = n_len; out_chunk.n_mel = out_full.n_mel; - out_chunk.n_len_org = out_full.n_mel; // unused + out_chunk.n_len_org = out_full.n_mel; // unused out_chunk.data.reserve(out_chunk.n_mel * out_chunk.n_len); for (int i = 0; i < out_full.n_mel; i++) { - auto src = out_full.data.begin() + i*out_full.n_len + off; + auto src = out_full.data.begin() + i * out_full.n_len + off; out_chunk.data.insert(out_chunk.data.end(), src, src + frames_per_chunk); } @@ -541,18 +585,14 @@ bool mtmd_audio_preprocessor_whisper::preprocess( // void mtmd_audio_preprocessor_conformer::initialize() { - g_cache.fill_sin_cos_table(hparams.audio_n_fft); - g_cache.fill_hann_window(hparams.audio_window_len, true); - g_cache.fill_mel_filterbank_matrix( - hparams.n_mel_bins, - hparams.audio_n_fft, - hparams.audio_sample_rate); + cache.fill_sin_cos_table(hparams.audio_n_fft); + cache.fill_hann_window(hparams.audio_window_len, true); + cache.fill_mel_filterbank_matrix(hparams.n_mel_bins, hparams.audio_n_fft, hparams.audio_sample_rate); } -bool mtmd_audio_preprocessor_conformer::preprocess( - const float * samples, - size_t n_samples, - std::vector & output) { +bool mtmd_audio_preprocessor_conformer::preprocess(const float * samples, + size_t n_samples, + std::vector & output) { // empty audio if (n_samples == 0) { return false; @@ -569,18 +609,15 @@ bool mtmd_audio_preprocessor_conformer::preprocess( params.use_natural_log = true; params.norm_per_feature = true; - // make sure the global cache is initialized - GGML_ASSERT(!g_cache.sin_vals.empty()); - GGML_ASSERT(!g_cache.cos_vals.empty()); - GGML_ASSERT(!g_cache.filters.data.empty()); + // make sure the cache is initialized + GGML_ASSERT(!cache.sin_vals.empty()); + GGML_ASSERT(!cache.cos_vals.empty()); + GGML_ASSERT(!cache.filters.data.empty()); mtmd_audio_mel out_full; - bool ok = log_mel_spectrogram( - samples, - n_samples, - 4, // n_threads - params, - out_full); + bool ok = log_mel_spectrogram(samples, n_samples, + 4, // n_threads + params, cache, out_full); if (!ok) { return false; } @@ -588,3 +625,106 @@ bool mtmd_audio_preprocessor_conformer::preprocess( output.push_back(std::move(out_full)); return true; } + +// +// mtmd_audio_streaming_istft implementation +// + +mtmd_audio_streaming_istft::mtmd_audio_streaming_istft(int n_fft, int hop_length) : + n_fft(n_fft), + hop_length(hop_length), + n_fft_bins(n_fft / 2 + 1), + overlap_buffer(n_fft, 0.0f), + window_sum_buffer(n_fft, 0.0f), + padding_to_remove((n_fft - hop_length) / 2), + ifft_in(n_fft * 2 * 4, 0.0f), // extra space for recursive IFFT + ifft_out(n_fft * 2 * 4, 0.0f) { + cache.fill_sin_cos_table(n_fft); + cache.fill_hann_window(n_fft, true); +} + +void mtmd_audio_streaming_istft::reset() { + std::fill(overlap_buffer.begin(), overlap_buffer.end(), 0.0f); + std::fill(window_sum_buffer.begin(), window_sum_buffer.end(), 0.0f); + padding_to_remove = (n_fft - hop_length) / 2; +} + +std::vector mtmd_audio_streaming_istft::process_frame(const float * frame_spectrum) { + std::vector output(hop_length); + + // copy frequencies + for (int j = 0; j < n_fft_bins; j++) { + ifft_in[j * 2 + 0] = frame_spectrum[j * 2 + 0]; + ifft_in[j * 2 + 1] = frame_spectrum[j * 2 + 1]; + } + + // mirror negative frequencies + for (int j = 1; j < n_fft_bins - 1; j++) { + int mirror_idx = n_fft - j; + ifft_in[mirror_idx * 2 + 0] = ifft_in[j * 2 + 0]; + ifft_in[mirror_idx * 2 + 1] = -ifft_in[j * 2 + 1]; // conjugate + } + + ifft(cache, ifft_in.data(), n_fft, ifft_out.data()); + + // update window sum and overlap buffer + for (int j = 0; j < n_fft; j++) { + window_sum_buffer[j] += cache.hann_window[j] * cache.hann_window[j]; + overlap_buffer[j] += ifft_out[j * 2] * cache.hann_window[j]; + } + + // extract hop_length samples with normalization + for (int i = 0; i < hop_length; i++) { + if (window_sum_buffer[i] > 1e-8f) { + output[i] = overlap_buffer[i] / window_sum_buffer[i]; + } else { + output[i] = overlap_buffer[i]; + } + } + + // shift buffers left by hop_length + std::copy(overlap_buffer.begin() + hop_length, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - hop_length, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + hop_length, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - hop_length, window_sum_buffer.end(), 0.0f); + + // Remove padding if needed + int to_remove = std::min(padding_to_remove, (int) output.size()); + padding_to_remove -= to_remove; + output.erase(output.begin(), output.begin() + to_remove); + + return output; +} + +std::vector mtmd_audio_streaming_istft::flush() { + std::vector output; + + // Extract remaining samples from overlap buffer + // Continue until we've extracted all meaningful samples + int remaining = n_fft - hop_length; + while (remaining > 0) { + int chunk_size = std::min(remaining, hop_length); + + for (int i = 0; i < chunk_size; i++) { + float sample; + if (window_sum_buffer[i] > 1e-8f) { + sample = overlap_buffer[i] / window_sum_buffer[i]; + } else { + sample = overlap_buffer[i]; + } + output.push_back(sample); + } + + // Shift buffers + std::copy(overlap_buffer.begin() + chunk_size, overlap_buffer.end(), overlap_buffer.begin()); + std::fill(overlap_buffer.end() - chunk_size, overlap_buffer.end(), 0.0f); + + std::copy(window_sum_buffer.begin() + chunk_size, window_sum_buffer.end(), window_sum_buffer.begin()); + std::fill(window_sum_buffer.end() - chunk_size, window_sum_buffer.end(), 0.0f); + + remaining -= chunk_size; + } + + return output; +} diff --git a/tools/mtmd/mtmd-audio.h b/tools/mtmd/mtmd-audio.h index d484c9d030..016c7392e4 100644 --- a/tools/mtmd/mtmd-audio.h +++ b/tools/mtmd/mtmd-audio.h @@ -17,6 +17,38 @@ struct mtmd_audio_mel { std::vector data; }; +struct mtmd_audio_mel_filters { + int32_t n_mel; + int32_t n_fft; + + std::vector data; +}; + +// cache for audio processing, each processor instance owns its own cache +struct mtmd_audio_cache { + std::vector sin_vals; + std::vector cos_vals; + + std::vector hann_window; + + mtmd_audio_mel_filters filters; + + void fill_sin_cos_table(int n); + + void fill_hann_window(int length, bool periodic); + + // Build mel filterbank matrix [n_mel × n_fft_bins] at runtime. + // n_fft_bins must be (N_fft / 2 + 1). Example: if N_fft=512 -> n_fft_bins=257. + void fill_mel_filterbank_matrix(int n_mel, + int n_fft, + int sample_rate, // e.g. 16000 + float fmin = 0.0f, // e.g. 0.0 + float fmax = -1.0f, // e.g. sr/2; pass -1 for auto + bool slaney_area_norm = true, + float scale = 1.0f // optional extra scaling + ); +}; + struct mtmd_audio_preprocessor { const clip_hparams & hparams; @@ -31,10 +63,51 @@ struct mtmd_audio_preprocessor_whisper : mtmd_audio_preprocessor { mtmd_audio_preprocessor_whisper(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; }; struct mtmd_audio_preprocessor_conformer : mtmd_audio_preprocessor { mtmd_audio_preprocessor_conformer(const clip_ctx * ctx) : mtmd_audio_preprocessor(ctx) {} void initialize() override; bool preprocess(const float * samples, size_t n_samples, std::vector & output) override; + + private: + mtmd_audio_cache cache; +}; + +// +// streaming ISTFT - converts spectrogram frames back to audio one frame at a time +// +struct mtmd_audio_streaming_istft { + mtmd_audio_streaming_istft(int n_fft, int hop_length); + + // reset streaming state + void reset(); + + // process a single STFT frame (streaming) + // frame_spectrum: [n_fft_bins x 2] interleaved real/imag + // returns: up to hop_length samples + std::vector process_frame(const float * frame_spectrum); + + // flush remaining samples at end of stream + std::vector flush(); + + private: + int n_fft; + int hop_length; + int n_fft_bins; + + // Own cache for output processing + mtmd_audio_cache cache; + + // Streaming state + std::vector overlap_buffer; + std::vector window_sum_buffer; + int padding_to_remove; + + // Working buffers for IFFT + std::vector ifft_in; + std::vector ifft_out; }; diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp index b0b5ab42ab..fca55b76f8 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -283,7 +283,7 @@ struct mtmd_context { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md img_end = "[IMG_END]"; - } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL) { + } else if (proj == PROJECTOR_TYPE_QWEN2VL || proj == PROJECTOR_TYPE_QWEN25VL || proj == PROJECTOR_TYPE_QWEN3VL || proj == PROJECTOR_TYPE_YOUTUVL) { // <|vision_start|> ... (image embeddings) ... <|vision_end|> img_beg = "<|vision_start|>"; img_end = "<|vision_end|>"; diff --git a/tools/run/CMakeLists.txt b/tools/run/CMakeLists.txt deleted file mode 100644 index 6ad7534e29..0000000000 --- a/tools/run/CMakeLists.txt +++ /dev/null @@ -1,23 +0,0 @@ -set(TARGET llama-run) -add_executable(${TARGET} run.cpp linenoise.cpp/linenoise.cpp) - -# TODO: avoid copying this code block from common/CMakeLists.txt -set(LLAMA_RUN_EXTRA_LIBS "") -if (LLAMA_CURL) - find_package(CURL REQUIRED) - target_compile_definitions(${TARGET} PUBLIC LLAMA_USE_CURL) - include_directories(${CURL_INCLUDE_DIRS}) - set(LLAMA_RUN_EXTRA_LIBS ${LLAMA_RUN_EXTRA_LIBS} ${CURL_LIBRARIES}) -endif () - -if(LLAMA_TOOLS_INSTALL) - install(TARGETS ${TARGET} RUNTIME) -endif() - -if (CMAKE_SYSTEM_NAME MATCHES "AIX") - # AIX's flock() function comes from libbsd.a - target_link_libraries(${TARGET} PRIVATE -lbsd) -endif() - -target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT} ${LLAMA_RUN_EXTRA_LIBS}) -target_compile_features(${TARGET} PRIVATE cxx_std_17) diff --git a/tools/run/README.md b/tools/run/README.md deleted file mode 100644 index 5fd769b44c..0000000000 --- a/tools/run/README.md +++ /dev/null @@ -1,52 +0,0 @@ -# llama.cpp/example/run - -The purpose of this example is to demonstrate a minimal usage of llama.cpp for running models. - -```bash -llama-run granite3-moe -``` - -```bash -Description: - Runs a llm - -Usage: - llama-run [options] model [prompt] - -Options: - -c, --context-size - Context size (default: 2048) - -n, -ngl, --ngl - Number of GPU layers (default: 0) - --temp - Temperature (default: 0.8) - -v, --verbose, --log-verbose - Set verbosity level to infinity (i.e. log all messages, useful for debugging) - -h, --help - Show help message - -Commands: - model - Model is a string with an optional prefix of - huggingface:// (hf://), ollama://, https:// or file://. - If no protocol is specified and a file exists in the specified - path, file:// is assumed, otherwise if a file does not exist in - the specified path, ollama:// is assumed. Models that are being - pulled are downloaded with .partial extension while being - downloaded and then renamed as the file without the .partial - extension when complete. - -Examples: - llama-run llama3 - llama-run ollama://granite-code - llama-run ollama://smollm:135m - llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf - llama-run modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf - llama-run https://example.com/some-file1.gguf - llama-run some-file2.gguf - llama-run file://some-file3.gguf - llama-run --ngl 999 some-file4.gguf - llama-run --ngl 999 some-file5.gguf Hello World -``` diff --git a/tools/run/linenoise.cpp/linenoise.cpp b/tools/run/linenoise.cpp/linenoise.cpp deleted file mode 100644 index 9cb9399003..0000000000 --- a/tools/run/linenoise.cpp/linenoise.cpp +++ /dev/null @@ -1,1995 +0,0 @@ -#ifndef _WIN32 -/* - * You can find the latest source code at: - * - * http://github.com/ericcurtin/linenoise.cpp - * - * Does a number of crazy assumptions that happen to be true in 99.9999% of - * the 2010 UNIX computers around. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - * - * ------------------------------------------------------------------------ - * - * References: - * - http://invisible-island.net/xterm/ctlseqs/ctlseqs.html - * - http://www.3waylabs.com/nw/WWW/products/wizcon/vt220.html - * - * Todo list: - * - Filter bogus Ctrl+ combinations. - * - Win32 support - * - * Bloat: - * - History search like Ctrl+r in readline? - * - * List of escape sequences used by this program, we do everything just - * with three sequences. In order to be so cheap we may have some - * flickering effect with some slow terminal, but the lesser sequences - * the more compatible. - * - * EL (Erase Line) - * Sequence: ESC [ n K - * Effect: if n is 0 or missing, clear from cursor to end of line - * Effect: if n is 1, clear from beginning of line to cursor - * Effect: if n is 2, clear entire line - * - * CUF (CUrsor Forward) - * Sequence: ESC [ n C - * Effect: moves cursor forward n chars - * - * CUB (CUrsor Backward) - * Sequence: ESC [ n D - * Effect: moves cursor backward n chars - * - * The following is used to get the terminal width if getting - * the width with the TIOCGWINSZ ioctl fails - * - * DSR (Device Status Report) - * Sequence: ESC [ 6 n - * Effect: reports the current cursor position as ESC [ n ; m R - * where n is the row and m is the column - * - * When multi line mode is enabled, we also use an additional escape - * sequence. However multi line editing is disabled by default. - * - * CUU (Cursor Up) - * Sequence: ESC [ n A - * Effect: moves cursor up of n chars. - * - * CUD (Cursor Down) - * Sequence: ESC [ n B - * Effect: moves cursor down of n chars. - * - * When linenoiseClearScreen() is called, two additional escape sequences - * are used in order to clear the screen and position the cursor at home - * position. - * - * CUP (Cursor position) - * Sequence: ESC [ H - * Effect: moves the cursor to upper left corner - * - * ED (Erase display) - * Sequence: ESC [ 2 J - * Effect: clear the whole screen - * - */ - -# include "linenoise.h" - -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include -# include - -# include -# include -# include - -# define LINENOISE_DEFAULT_HISTORY_MAX_LEN 100 -# define LINENOISE_MAX_LINE 4096 -static std::vector unsupported_term = { "dumb", "cons25", "emacs" }; -static linenoiseCompletionCallback *completionCallback = NULL; -static linenoiseHintsCallback *hintsCallback = NULL; -static linenoiseFreeHintsCallback *freeHintsCallback = NULL; -static char *linenoiseNoTTY(void); -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags); -static void refreshLineWithFlags(struct linenoiseState *l, int flags); - -static struct termios orig_termios; /* In order to restore at exit.*/ -static int maskmode = 0; /* Show "***" instead of input. For passwords. */ -static int rawmode = 0; /* For atexit() function to check if restore is needed*/ -static int mlmode = 0; /* Multi line mode. Default is single line. */ -static int atexit_registered = 0; /* Register atexit just 1 time. */ -static int history_max_len = LINENOISE_DEFAULT_HISTORY_MAX_LEN; -static int history_len = 0; -static char **history = NULL; - -enum KEY_ACTION{ - KEY_NULL = 0, /* NULL */ - CTRL_A = 1, /* Ctrl+a */ - CTRL_B = 2, /* Ctrl-b */ - CTRL_C = 3, /* Ctrl-c */ - CTRL_D = 4, /* Ctrl-d */ - CTRL_E = 5, /* Ctrl-e */ - CTRL_F = 6, /* Ctrl-f */ - CTRL_H = 8, /* Ctrl-h */ - TAB = 9, /* Tab */ - CTRL_K = 11, /* Ctrl+k */ - CTRL_L = 12, /* Ctrl+l */ - ENTER = 13, /* Enter */ - CTRL_N = 14, /* Ctrl-n */ - CTRL_P = 16, /* Ctrl-p */ - CTRL_T = 20, /* Ctrl-t */ - CTRL_U = 21, /* Ctrl+u */ - CTRL_W = 23, /* Ctrl+w */ - ESC = 27, /* Escape */ - BACKSPACE = 127 /* Backspace */ -}; - -static void linenoiseAtExit(void); -int linenoiseHistoryAdd(const char *line); -#define REFRESH_CLEAN (1<<0) // Clean the old prompt from the screen -#define REFRESH_WRITE (1<<1) // Rewrite the prompt on the screen. -#define REFRESH_ALL (REFRESH_CLEAN|REFRESH_WRITE) // Do both. -static void refreshLine(struct linenoiseState *l); - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } - } - - return 0; - } - - ~File() { - if (fd >= 0) { - flock(fd, LOCK_UN); - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -}; - -#if 0 -/* Debugging function. */ -__attribute__((format(printf, 1, 2))) -static void lndebug(const char *fmt, ...) { - static File file; - if (file.file == nullptr) { - file.open("/tmp/lndebug.txt", "a"); - } - - if (file.file != nullptr) { - va_list args; - va_start(args, fmt); - vfprintf(file.file, fmt, args); - va_end(args); - fflush(file.file); - } -} -#endif - -/* ========================== Encoding functions ============================= */ - -/* Get length of previous UTF8 codepoint */ -static size_t prevUtf8CodePointLen(const char * buf, int pos) { - int end = pos--; - while (pos >= 0 && ((unsigned char) buf[pos] & 0xC0) == 0x80) { - pos--; - } - return end - pos; -} - -/* Convert UTF8 to Unicode code point */ -static size_t utf8BytesToCodePoint(const char * buf, size_t len, int * cp) { - if (len) { - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - *cp = byte; - return 1; - } else if ((byte & 0xE0) == 0xC0) { - if (len >= 2) { - *cp = (((unsigned long) (buf[0] & 0x1F)) << 6) | ((unsigned long) (buf[1] & 0x3F)); - return 2; - } - } else if ((byte & 0xF0) == 0xE0) { - if (len >= 3) { - *cp = (((unsigned long) (buf[0] & 0x0F)) << 12) | (((unsigned long) (buf[1] & 0x3F)) << 6) | - ((unsigned long) (buf[2] & 0x3F)); - return 3; - } - } else if ((byte & 0xF8) == 0xF0) { - if (len >= 4) { - *cp = (((unsigned long) (buf[0] & 0x07)) << 18) | (((unsigned long) (buf[1] & 0x3F)) << 12) | - (((unsigned long) (buf[2] & 0x3F)) << 6) | ((unsigned long) (buf[3] & 0x3F)); - return 4; - } - } - } - return 0; -} - -/* Check if the code is a wide character */ -static const unsigned long wideCharTable[][2] = { - /* BEGIN: WIDE CHAR TABLE */ - { 0x1100, 0x115F }, - { 0x231A, 0x231B }, - { 0x2329, 0x232A }, - { 0x23E9, 0x23EC }, - { 0x23F0, 0x23F0 }, - { 0x23F3, 0x23F3 }, - { 0x25FD, 0x25FE }, - { 0x2614, 0x2615 }, - { 0x2630, 0x2637 }, - { 0x2648, 0x2653 }, - { 0x267F, 0x267F }, - { 0x268A, 0x268F }, - { 0x2693, 0x2693 }, - { 0x26A1, 0x26A1 }, - { 0x26AA, 0x26AB }, - { 0x26BD, 0x26BE }, - { 0x26C4, 0x26C5 }, - { 0x26CE, 0x26CE }, - { 0x26D4, 0x26D4 }, - { 0x26EA, 0x26EA }, - { 0x26F2, 0x26F3 }, - { 0x26F5, 0x26F5 }, - { 0x26FA, 0x26FA }, - { 0x26FD, 0x26FD }, - { 0x2705, 0x2705 }, - { 0x270A, 0x270B }, - { 0x2728, 0x2728 }, - { 0x274C, 0x274C }, - { 0x274E, 0x274E }, - { 0x2753, 0x2755 }, - { 0x2757, 0x2757 }, - { 0x2795, 0x2797 }, - { 0x27B0, 0x27B0 }, - { 0x27BF, 0x27BF }, - { 0x2B1B, 0x2B1C }, - { 0x2B50, 0x2B50 }, - { 0x2B55, 0x2B55 }, - { 0x2E80, 0x2E99 }, - { 0x2E9B, 0x2EF3 }, - { 0x2F00, 0x2FD5 }, - { 0x2FF0, 0x303E }, - { 0x3041, 0x3096 }, - { 0x3099, 0x30FF }, - { 0x3105, 0x312F }, - { 0x3131, 0x318E }, - { 0x3190, 0x31E5 }, - { 0x31EF, 0x321E }, - { 0x3220, 0x3247 }, - { 0x3250, 0xA48C }, - { 0xA490, 0xA4C6 }, - { 0xA960, 0xA97C }, - { 0xAC00, 0xD7A3 }, - { 0xF900, 0xFAFF }, - { 0xFE10, 0xFE19 }, - { 0xFE30, 0xFE52 }, - { 0xFE54, 0xFE66 }, - { 0xFE68, 0xFE6B }, - { 0xFF01, 0xFF60 }, - { 0xFFE0, 0xFFE6 }, - { 0x16FE0, 0x16FE4 }, - { 0x16FF0, 0x16FF1 }, - { 0x17000, 0x187F7 }, - { 0x18800, 0x18CD5 }, - { 0x18CFF, 0x18D08 }, - { 0x1AFF0, 0x1AFF3 }, - { 0x1AFF5, 0x1AFFB }, - { 0x1AFFD, 0x1AFFE }, - { 0x1B000, 0x1B122 }, - { 0x1B132, 0x1B132 }, - { 0x1B150, 0x1B152 }, - { 0x1B155, 0x1B155 }, - { 0x1B164, 0x1B167 }, - { 0x1B170, 0x1B2FB }, - { 0x1D300, 0x1D356 }, - { 0x1D360, 0x1D376 }, - { 0x1F004, 0x1F004 }, - { 0x1F0CF, 0x1F0CF }, - { 0x1F18E, 0x1F18E }, - { 0x1F191, 0x1F19A }, - { 0x1F200, 0x1F202 }, - { 0x1F210, 0x1F23B }, - { 0x1F240, 0x1F248 }, - { 0x1F250, 0x1F251 }, - { 0x1F260, 0x1F265 }, - { 0x1F300, 0x1F320 }, - { 0x1F32D, 0x1F335 }, - { 0x1F337, 0x1F37C }, - { 0x1F37E, 0x1F393 }, - { 0x1F3A0, 0x1F3CA }, - { 0x1F3CF, 0x1F3D3 }, - { 0x1F3E0, 0x1F3F0 }, - { 0x1F3F4, 0x1F3F4 }, - { 0x1F3F8, 0x1F43E }, - { 0x1F440, 0x1F440 }, - { 0x1F442, 0x1F4FC }, - { 0x1F4FF, 0x1F53D }, - { 0x1F54B, 0x1F54E }, - { 0x1F550, 0x1F567 }, - { 0x1F57A, 0x1F57A }, - { 0x1F595, 0x1F596 }, - { 0x1F5A4, 0x1F5A4 }, - { 0x1F5FB, 0x1F64F }, - { 0x1F680, 0x1F6C5 }, - { 0x1F6CC, 0x1F6CC }, - { 0x1F6D0, 0x1F6D2 }, - { 0x1F6D5, 0x1F6D7 }, - { 0x1F6DC, 0x1F6DF }, - { 0x1F6EB, 0x1F6EC }, - { 0x1F6F4, 0x1F6FC }, - { 0x1F7E0, 0x1F7EB }, - { 0x1F7F0, 0x1F7F0 }, - { 0x1F90C, 0x1F93A }, - { 0x1F93C, 0x1F945 }, - { 0x1F947, 0x1F9FF }, - { 0x1FA70, 0x1FA7C }, - { 0x1FA80, 0x1FA89 }, - { 0x1FA8F, 0x1FAC6 }, - { 0x1FACE, 0x1FADC }, - { 0x1FADF, 0x1FAE9 }, - { 0x1FAF0, 0x1FAF8 }, - { 0x20000, 0x2FFFD }, - { 0x30000, 0x3FFFD } - /* END: WIDE CHAR TABLE */ -}; - -static const size_t wideCharTableSize = sizeof(wideCharTable) / sizeof(wideCharTable[0]); - -static bool isWideChar(unsigned long cp) { - for (size_t i = 0; i < wideCharTableSize; i++) { - auto first_code = wideCharTable[i][0]; - auto last_code = wideCharTable[i][1]; - if (first_code > cp) { - return false; - } - if (first_code <= cp && cp <= last_code) { - return true; - } - } - return false; -} - -/* Check if the code is a combining character */ -static const unsigned long combiningCharTable[] = { - /* BEGIN: COMBINING CHAR TABLE */ - 0x0300, 0x0301, 0x0302, 0x0303, 0x0304, 0x0305, 0x0306, 0x0307, 0x0308, 0x0309, 0x030A, 0x030B, 0x030C, - 0x030D, 0x030E, 0x030F, 0x0310, 0x0311, 0x0312, 0x0313, 0x0314, 0x0315, 0x0316, 0x0317, 0x0318, 0x0319, - 0x031A, 0x031B, 0x031C, 0x031D, 0x031E, 0x031F, 0x0320, 0x0321, 0x0322, 0x0323, 0x0324, 0x0325, 0x0326, - 0x0327, 0x0328, 0x0329, 0x032A, 0x032B, 0x032C, 0x032D, 0x032E, 0x032F, 0x0330, 0x0331, 0x0332, 0x0333, - 0x0334, 0x0335, 0x0336, 0x0337, 0x0338, 0x0339, 0x033A, 0x033B, 0x033C, 0x033D, 0x033E, 0x033F, 0x0340, - 0x0341, 0x0342, 0x0343, 0x0344, 0x0345, 0x0346, 0x0347, 0x0348, 0x0349, 0x034A, 0x034B, 0x034C, 0x034D, - 0x034E, 0x034F, 0x0350, 0x0351, 0x0352, 0x0353, 0x0354, 0x0355, 0x0356, 0x0357, 0x0358, 0x0359, 0x035A, - 0x035B, 0x035C, 0x035D, 0x035E, 0x035F, 0x0360, 0x0361, 0x0362, 0x0363, 0x0364, 0x0365, 0x0366, 0x0367, - 0x0368, 0x0369, 0x036A, 0x036B, 0x036C, 0x036D, 0x036E, 0x036F, 0x0483, 0x0484, 0x0485, 0x0486, 0x0487, - 0x0591, 0x0592, 0x0593, 0x0594, 0x0595, 0x0596, 0x0597, 0x0598, 0x0599, 0x059A, 0x059B, 0x059C, 0x059D, - 0x059E, 0x059F, 0x05A0, 0x05A1, 0x05A2, 0x05A3, 0x05A4, 0x05A5, 0x05A6, 0x05A7, 0x05A8, 0x05A9, 0x05AA, - 0x05AB, 0x05AC, 0x05AD, 0x05AE, 0x05AF, 0x05B0, 0x05B1, 0x05B2, 0x05B3, 0x05B4, 0x05B5, 0x05B6, 0x05B7, - 0x05B8, 0x05B9, 0x05BA, 0x05BB, 0x05BC, 0x05BD, 0x05BF, 0x05C1, 0x05C2, 0x05C4, 0x05C5, 0x05C7, 0x0610, - 0x0611, 0x0612, 0x0613, 0x0614, 0x0615, 0x0616, 0x0617, 0x0618, 0x0619, 0x061A, 0x064B, 0x064C, 0x064D, - 0x064E, 0x064F, 0x0650, 0x0651, 0x0652, 0x0653, 0x0654, 0x0655, 0x0656, 0x0657, 0x0658, 0x0659, 0x065A, - 0x065B, 0x065C, 0x065D, 0x065E, 0x065F, 0x0670, 0x06D6, 0x06D7, 0x06D8, 0x06D9, 0x06DA, 0x06DB, 0x06DC, - 0x06DF, 0x06E0, 0x06E1, 0x06E2, 0x06E3, 0x06E4, 0x06E7, 0x06E8, 0x06EA, 0x06EB, 0x06EC, 0x06ED, 0x0711, - 0x0730, 0x0731, 0x0732, 0x0733, 0x0734, 0x0735, 0x0736, 0x0737, 0x0738, 0x0739, 0x073A, 0x073B, 0x073C, - 0x073D, 0x073E, 0x073F, 0x0740, 0x0741, 0x0742, 0x0743, 0x0744, 0x0745, 0x0746, 0x0747, 0x0748, 0x0749, - 0x074A, 0x07A6, 0x07A7, 0x07A8, 0x07A9, 0x07AA, 0x07AB, 0x07AC, 0x07AD, 0x07AE, 0x07AF, 0x07B0, 0x07EB, - 0x07EC, 0x07ED, 0x07EE, 0x07EF, 0x07F0, 0x07F1, 0x07F2, 0x07F3, 0x07FD, 0x0816, 0x0817, 0x0818, 0x0819, - 0x081B, 0x081C, 0x081D, 0x081E, 0x081F, 0x0820, 0x0821, 0x0822, 0x0823, 0x0825, 0x0826, 0x0827, 0x0829, - 0x082A, 0x082B, 0x082C, 0x082D, 0x0859, 0x085A, 0x085B, 0x0897, 0x0898, 0x0899, 0x089A, 0x089B, 0x089C, - 0x089D, 0x089E, 0x089F, 0x08CA, 0x08CB, 0x08CC, 0x08CD, 0x08CE, 0x08CF, 0x08D0, 0x08D1, 0x08D2, 0x08D3, - 0x08D4, 0x08D5, 0x08D6, 0x08D7, 0x08D8, 0x08D9, 0x08DA, 0x08DB, 0x08DC, 0x08DD, 0x08DE, 0x08DF, 0x08E0, - 0x08E1, 0x08E3, 0x08E4, 0x08E5, 0x08E6, 0x08E7, 0x08E8, 0x08E9, 0x08EA, 0x08EB, 0x08EC, 0x08ED, 0x08EE, - 0x08EF, 0x08F0, 0x08F1, 0x08F2, 0x08F3, 0x08F4, 0x08F5, 0x08F6, 0x08F7, 0x08F8, 0x08F9, 0x08FA, 0x08FB, - 0x08FC, 0x08FD, 0x08FE, 0x08FF, 0x0900, 0x0901, 0x0902, 0x093A, 0x093C, 0x0941, 0x0942, 0x0943, 0x0944, - 0x0945, 0x0946, 0x0947, 0x0948, 0x094D, 0x0951, 0x0952, 0x0953, 0x0954, 0x0955, 0x0956, 0x0957, 0x0962, - 0x0963, 0x0981, 0x09BC, 0x09C1, 0x09C2, 0x09C3, 0x09C4, 0x09CD, 0x09E2, 0x09E3, 0x09FE, 0x0A01, 0x0A02, - 0x0A3C, 0x0A41, 0x0A42, 0x0A47, 0x0A48, 0x0A4B, 0x0A4C, 0x0A4D, 0x0A51, 0x0A70, 0x0A71, 0x0A75, 0x0A81, - 0x0A82, 0x0ABC, 0x0AC1, 0x0AC2, 0x0AC3, 0x0AC4, 0x0AC5, 0x0AC7, 0x0AC8, 0x0ACD, 0x0AE2, 0x0AE3, 0x0AFA, - 0x0AFB, 0x0AFC, 0x0AFD, 0x0AFE, 0x0AFF, 0x0B01, 0x0B3C, 0x0B3F, 0x0B41, 0x0B42, 0x0B43, 0x0B44, 0x0B4D, - 0x0B55, 0x0B56, 0x0B62, 0x0B63, 0x0B82, 0x0BC0, 0x0BCD, 0x0C00, 0x0C04, 0x0C3C, 0x0C3E, 0x0C3F, 0x0C40, - 0x0C46, 0x0C47, 0x0C48, 0x0C4A, 0x0C4B, 0x0C4C, 0x0C4D, 0x0C55, 0x0C56, 0x0C62, 0x0C63, 0x0C81, 0x0CBC, - 0x0CBF, 0x0CC6, 0x0CCC, 0x0CCD, 0x0CE2, 0x0CE3, 0x0D00, 0x0D01, 0x0D3B, 0x0D3C, 0x0D41, 0x0D42, 0x0D43, - 0x0D44, 0x0D4D, 0x0D62, 0x0D63, 0x0D81, 0x0DCA, 0x0DD2, 0x0DD3, 0x0DD4, 0x0DD6, 0x0E31, 0x0E34, 0x0E35, - 0x0E36, 0x0E37, 0x0E38, 0x0E39, 0x0E3A, 0x0E47, 0x0E48, 0x0E49, 0x0E4A, 0x0E4B, 0x0E4C, 0x0E4D, 0x0E4E, - 0x0EB1, 0x0EB4, 0x0EB5, 0x0EB6, 0x0EB7, 0x0EB8, 0x0EB9, 0x0EBA, 0x0EBB, 0x0EBC, 0x0EC8, 0x0EC9, 0x0ECA, - 0x0ECB, 0x0ECC, 0x0ECD, 0x0ECE, 0x0F18, 0x0F19, 0x0F35, 0x0F37, 0x0F39, 0x0F71, 0x0F72, 0x0F73, 0x0F74, - 0x0F75, 0x0F76, 0x0F77, 0x0F78, 0x0F79, 0x0F7A, 0x0F7B, 0x0F7C, 0x0F7D, 0x0F7E, 0x0F80, 0x0F81, 0x0F82, - 0x0F83, 0x0F84, 0x0F86, 0x0F87, 0x0F8D, 0x0F8E, 0x0F8F, 0x0F90, 0x0F91, 0x0F92, 0x0F93, 0x0F94, 0x0F95, - 0x0F96, 0x0F97, 0x0F99, 0x0F9A, 0x0F9B, 0x0F9C, 0x0F9D, 0x0F9E, 0x0F9F, 0x0FA0, 0x0FA1, 0x0FA2, 0x0FA3, - 0x0FA4, 0x0FA5, 0x0FA6, 0x0FA7, 0x0FA8, 0x0FA9, 0x0FAA, 0x0FAB, 0x0FAC, 0x0FAD, 0x0FAE, 0x0FAF, 0x0FB0, - 0x0FB1, 0x0FB2, 0x0FB3, 0x0FB4, 0x0FB5, 0x0FB6, 0x0FB7, 0x0FB8, 0x0FB9, 0x0FBA, 0x0FBB, 0x0FBC, 0x0FC6, - 0x102D, 0x102E, 0x102F, 0x1030, 0x1032, 0x1033, 0x1034, 0x1035, 0x1036, 0x1037, 0x1039, 0x103A, 0x103D, - 0x103E, 0x1058, 0x1059, 0x105E, 0x105F, 0x1060, 0x1071, 0x1072, 0x1073, 0x1074, 0x1082, 0x1085, 0x1086, - 0x108D, 0x109D, 0x135D, 0x135E, 0x135F, 0x1712, 0x1713, 0x1714, 0x1732, 0x1733, 0x1752, 0x1753, 0x1772, - 0x1773, 0x17B4, 0x17B5, 0x17B7, 0x17B8, 0x17B9, 0x17BA, 0x17BB, 0x17BC, 0x17BD, 0x17C6, 0x17C9, 0x17CA, - 0x17CB, 0x17CC, 0x17CD, 0x17CE, 0x17CF, 0x17D0, 0x17D1, 0x17D2, 0x17D3, 0x17DD, 0x180B, 0x180C, 0x180D, - 0x180F, 0x1885, 0x1886, 0x18A9, 0x1920, 0x1921, 0x1922, 0x1927, 0x1928, 0x1932, 0x1939, 0x193A, 0x193B, - 0x1A17, 0x1A18, 0x1A1B, 0x1A56, 0x1A58, 0x1A59, 0x1A5A, 0x1A5B, 0x1A5C, 0x1A5D, 0x1A5E, 0x1A60, 0x1A62, - 0x1A65, 0x1A66, 0x1A67, 0x1A68, 0x1A69, 0x1A6A, 0x1A6B, 0x1A6C, 0x1A73, 0x1A74, 0x1A75, 0x1A76, 0x1A77, - 0x1A78, 0x1A79, 0x1A7A, 0x1A7B, 0x1A7C, 0x1A7F, 0x1AB0, 0x1AB1, 0x1AB2, 0x1AB3, 0x1AB4, 0x1AB5, 0x1AB6, - 0x1AB7, 0x1AB8, 0x1AB9, 0x1ABA, 0x1ABB, 0x1ABC, 0x1ABD, 0x1ABF, 0x1AC0, 0x1AC1, 0x1AC2, 0x1AC3, 0x1AC4, - 0x1AC5, 0x1AC6, 0x1AC7, 0x1AC8, 0x1AC9, 0x1ACA, 0x1ACB, 0x1ACC, 0x1ACD, 0x1ACE, 0x1B00, 0x1B01, 0x1B02, - 0x1B03, 0x1B34, 0x1B36, 0x1B37, 0x1B38, 0x1B39, 0x1B3A, 0x1B3C, 0x1B42, 0x1B6B, 0x1B6C, 0x1B6D, 0x1B6E, - 0x1B6F, 0x1B70, 0x1B71, 0x1B72, 0x1B73, 0x1B80, 0x1B81, 0x1BA2, 0x1BA3, 0x1BA4, 0x1BA5, 0x1BA8, 0x1BA9, - 0x1BAB, 0x1BAC, 0x1BAD, 0x1BE6, 0x1BE8, 0x1BE9, 0x1BED, 0x1BEF, 0x1BF0, 0x1BF1, 0x1C2C, 0x1C2D, 0x1C2E, - 0x1C2F, 0x1C30, 0x1C31, 0x1C32, 0x1C33, 0x1C36, 0x1C37, 0x1CD0, 0x1CD1, 0x1CD2, 0x1CD4, 0x1CD5, 0x1CD6, - 0x1CD7, 0x1CD8, 0x1CD9, 0x1CDA, 0x1CDB, 0x1CDC, 0x1CDD, 0x1CDE, 0x1CDF, 0x1CE0, 0x1CE2, 0x1CE3, 0x1CE4, - 0x1CE5, 0x1CE6, 0x1CE7, 0x1CE8, 0x1CED, 0x1CF4, 0x1CF8, 0x1CF9, 0x1DC0, 0x1DC1, 0x1DC2, 0x1DC3, 0x1DC4, - 0x1DC5, 0x1DC6, 0x1DC7, 0x1DC8, 0x1DC9, 0x1DCA, 0x1DCB, 0x1DCC, 0x1DCD, 0x1DCE, 0x1DCF, 0x1DD0, 0x1DD1, - 0x1DD2, 0x1DD3, 0x1DD4, 0x1DD5, 0x1DD6, 0x1DD7, 0x1DD8, 0x1DD9, 0x1DDA, 0x1DDB, 0x1DDC, 0x1DDD, 0x1DDE, - 0x1DDF, 0x1DE0, 0x1DE1, 0x1DE2, 0x1DE3, 0x1DE4, 0x1DE5, 0x1DE6, 0x1DE7, 0x1DE8, 0x1DE9, 0x1DEA, 0x1DEB, - 0x1DEC, 0x1DED, 0x1DEE, 0x1DEF, 0x1DF0, 0x1DF1, 0x1DF2, 0x1DF3, 0x1DF4, 0x1DF5, 0x1DF6, 0x1DF7, 0x1DF8, - 0x1DF9, 0x1DFA, 0x1DFB, 0x1DFC, 0x1DFD, 0x1DFE, 0x1DFF, 0x20D0, 0x20D1, 0x20D2, 0x20D3, 0x20D4, 0x20D5, - 0x20D6, 0x20D7, 0x20D8, 0x20D9, 0x20DA, 0x20DB, 0x20DC, 0x20E1, 0x20E5, 0x20E6, 0x20E7, 0x20E8, 0x20E9, - 0x20EA, 0x20EB, 0x20EC, 0x20ED, 0x20EE, 0x20EF, 0x20F0, 0x2CEF, 0x2CF0, 0x2CF1, 0x2D7F, 0x2DE0, 0x2DE1, - 0x2DE2, 0x2DE3, 0x2DE4, 0x2DE5, 0x2DE6, 0x2DE7, 0x2DE8, 0x2DE9, 0x2DEA, 0x2DEB, 0x2DEC, 0x2DED, 0x2DEE, - 0x2DEF, 0x2DF0, 0x2DF1, 0x2DF2, 0x2DF3, 0x2DF4, 0x2DF5, 0x2DF6, 0x2DF7, 0x2DF8, 0x2DF9, 0x2DFA, 0x2DFB, - 0x2DFC, 0x2DFD, 0x2DFE, 0x2DFF, 0x302A, 0x302B, 0x302C, 0x302D, 0x3099, 0x309A, 0xA66F, 0xA674, 0xA675, - 0xA676, 0xA677, 0xA678, 0xA679, 0xA67A, 0xA67B, 0xA67C, 0xA67D, 0xA69E, 0xA69F, 0xA6F0, 0xA6F1, 0xA802, - 0xA806, 0xA80B, 0xA825, 0xA826, 0xA82C, 0xA8C4, 0xA8C5, 0xA8E0, 0xA8E1, 0xA8E2, 0xA8E3, 0xA8E4, 0xA8E5, - 0xA8E6, 0xA8E7, 0xA8E8, 0xA8E9, 0xA8EA, 0xA8EB, 0xA8EC, 0xA8ED, 0xA8EE, 0xA8EF, 0xA8F0, 0xA8F1, 0xA8FF, - 0xA926, 0xA927, 0xA928, 0xA929, 0xA92A, 0xA92B, 0xA92C, 0xA92D, 0xA947, 0xA948, 0xA949, 0xA94A, 0xA94B, - 0xA94C, 0xA94D, 0xA94E, 0xA94F, 0xA950, 0xA951, 0xA980, 0xA981, 0xA982, 0xA9B3, 0xA9B6, 0xA9B7, 0xA9B8, - 0xA9B9, 0xA9BC, 0xA9BD, 0xA9E5, 0xAA29, 0xAA2A, 0xAA2B, 0xAA2C, 0xAA2D, 0xAA2E, 0xAA31, 0xAA32, 0xAA35, - 0xAA36, 0xAA43, 0xAA4C, 0xAA7C, 0xAAB0, 0xAAB2, 0xAAB3, 0xAAB4, 0xAAB7, 0xAAB8, 0xAABE, 0xAABF, 0xAAC1, - 0xAAEC, 0xAAED, 0xAAF6, 0xABE5, 0xABE8, 0xABED, 0xFB1E, 0xFE00, 0xFE01, 0xFE02, 0xFE03, 0xFE04, 0xFE05, - 0xFE06, 0xFE07, 0xFE08, 0xFE09, 0xFE0A, 0xFE0B, 0xFE0C, 0xFE0D, 0xFE0E, 0xFE0F, 0xFE20, 0xFE21, 0xFE22, - 0xFE23, 0xFE24, 0xFE25, 0xFE26, 0xFE27, 0xFE28, 0xFE29, 0xFE2A, 0xFE2B, 0xFE2C, 0xFE2D, 0xFE2E, 0xFE2F, - 0x101FD, 0x102E0, 0x10376, 0x10377, 0x10378, 0x10379, 0x1037A, 0x10A01, 0x10A02, 0x10A03, 0x10A05, 0x10A06, 0x10A0C, - 0x10A0D, 0x10A0E, 0x10A0F, 0x10A38, 0x10A39, 0x10A3A, 0x10A3F, 0x10AE5, 0x10AE6, 0x10D24, 0x10D25, 0x10D26, 0x10D27, - 0x10D69, 0x10D6A, 0x10D6B, 0x10D6C, 0x10D6D, 0x10EAB, 0x10EAC, 0x10EFC, 0x10EFD, 0x10EFE, 0x10EFF, 0x10F46, 0x10F47, - 0x10F48, 0x10F49, 0x10F4A, 0x10F4B, 0x10F4C, 0x10F4D, 0x10F4E, 0x10F4F, 0x10F50, 0x10F82, 0x10F83, 0x10F84, 0x10F85, - 0x11001, 0x11038, 0x11039, 0x1103A, 0x1103B, 0x1103C, 0x1103D, 0x1103E, 0x1103F, 0x11040, 0x11041, 0x11042, 0x11043, - 0x11044, 0x11045, 0x11046, 0x11070, 0x11073, 0x11074, 0x1107F, 0x11080, 0x11081, 0x110B3, 0x110B4, 0x110B5, 0x110B6, - 0x110B9, 0x110BA, 0x110C2, 0x11100, 0x11101, 0x11102, 0x11127, 0x11128, 0x11129, 0x1112A, 0x1112B, 0x1112D, 0x1112E, - 0x1112F, 0x11130, 0x11131, 0x11132, 0x11133, 0x11134, 0x11173, 0x11180, 0x11181, 0x111B6, 0x111B7, 0x111B8, 0x111B9, - 0x111BA, 0x111BB, 0x111BC, 0x111BD, 0x111BE, 0x111C9, 0x111CA, 0x111CB, 0x111CC, 0x111CF, 0x1122F, 0x11230, 0x11231, - 0x11234, 0x11236, 0x11237, 0x1123E, 0x11241, 0x112DF, 0x112E3, 0x112E4, 0x112E5, 0x112E6, 0x112E7, 0x112E8, 0x112E9, - 0x112EA, 0x11300, 0x11301, 0x1133B, 0x1133C, 0x11340, 0x11366, 0x11367, 0x11368, 0x11369, 0x1136A, 0x1136B, 0x1136C, - 0x11370, 0x11371, 0x11372, 0x11373, 0x11374, 0x113BB, 0x113BC, 0x113BD, 0x113BE, 0x113BF, 0x113C0, 0x113CE, 0x113D0, - 0x113D2, 0x113E1, 0x113E2, 0x11438, 0x11439, 0x1143A, 0x1143B, 0x1143C, 0x1143D, 0x1143E, 0x1143F, 0x11442, 0x11443, - 0x11444, 0x11446, 0x1145E, 0x114B3, 0x114B4, 0x114B5, 0x114B6, 0x114B7, 0x114B8, 0x114BA, 0x114BF, 0x114C0, 0x114C2, - 0x114C3, 0x115B2, 0x115B3, 0x115B4, 0x115B5, 0x115BC, 0x115BD, 0x115BF, 0x115C0, 0x115DC, 0x115DD, 0x11633, 0x11634, - 0x11635, 0x11636, 0x11637, 0x11638, 0x11639, 0x1163A, 0x1163D, 0x1163F, 0x11640, 0x116AB, 0x116AD, 0x116B0, 0x116B1, - 0x116B2, 0x116B3, 0x116B4, 0x116B5, 0x116B7, 0x1171D, 0x1171F, 0x11722, 0x11723, 0x11724, 0x11725, 0x11727, 0x11728, - 0x11729, 0x1172A, 0x1172B, 0x1182F, 0x11830, 0x11831, 0x11832, 0x11833, 0x11834, 0x11835, 0x11836, 0x11837, 0x11839, - 0x1183A, 0x1193B, 0x1193C, 0x1193E, 0x11943, 0x119D4, 0x119D5, 0x119D6, 0x119D7, 0x119DA, 0x119DB, 0x119E0, 0x11A01, - 0x11A02, 0x11A03, 0x11A04, 0x11A05, 0x11A06, 0x11A07, 0x11A08, 0x11A09, 0x11A0A, 0x11A33, 0x11A34, 0x11A35, 0x11A36, - 0x11A37, 0x11A38, 0x11A3B, 0x11A3C, 0x11A3D, 0x11A3E, 0x11A47, 0x11A51, 0x11A52, 0x11A53, 0x11A54, 0x11A55, 0x11A56, - 0x11A59, 0x11A5A, 0x11A5B, 0x11A8A, 0x11A8B, 0x11A8C, 0x11A8D, 0x11A8E, 0x11A8F, 0x11A90, 0x11A91, 0x11A92, 0x11A93, - 0x11A94, 0x11A95, 0x11A96, 0x11A98, 0x11A99, 0x11C30, 0x11C31, 0x11C32, 0x11C33, 0x11C34, 0x11C35, 0x11C36, 0x11C38, - 0x11C39, 0x11C3A, 0x11C3B, 0x11C3C, 0x11C3D, 0x11C3F, 0x11C92, 0x11C93, 0x11C94, 0x11C95, 0x11C96, 0x11C97, 0x11C98, - 0x11C99, 0x11C9A, 0x11C9B, 0x11C9C, 0x11C9D, 0x11C9E, 0x11C9F, 0x11CA0, 0x11CA1, 0x11CA2, 0x11CA3, 0x11CA4, 0x11CA5, - 0x11CA6, 0x11CA7, 0x11CAA, 0x11CAB, 0x11CAC, 0x11CAD, 0x11CAE, 0x11CAF, 0x11CB0, 0x11CB2, 0x11CB3, 0x11CB5, 0x11CB6, - 0x11D31, 0x11D32, 0x11D33, 0x11D34, 0x11D35, 0x11D36, 0x11D3A, 0x11D3C, 0x11D3D, 0x11D3F, 0x11D40, 0x11D41, 0x11D42, - 0x11D43, 0x11D44, 0x11D45, 0x11D47, 0x11D90, 0x11D91, 0x11D95, 0x11D97, 0x11EF3, 0x11EF4, 0x11F00, 0x11F01, 0x11F36, - 0x11F37, 0x11F38, 0x11F39, 0x11F3A, 0x11F40, 0x11F42, 0x11F5A, 0x13440, 0x13447, 0x13448, 0x13449, 0x1344A, 0x1344B, - 0x1344C, 0x1344D, 0x1344E, 0x1344F, 0x13450, 0x13451, 0x13452, 0x13453, 0x13454, 0x13455, 0x1611E, 0x1611F, 0x16120, - 0x16121, 0x16122, 0x16123, 0x16124, 0x16125, 0x16126, 0x16127, 0x16128, 0x16129, 0x1612D, 0x1612E, 0x1612F, 0x16AF0, - 0x16AF1, 0x16AF2, 0x16AF3, 0x16AF4, 0x16B30, 0x16B31, 0x16B32, 0x16B33, 0x16B34, 0x16B35, 0x16B36, 0x16F4F, 0x16F8F, - 0x16F90, 0x16F91, 0x16F92, 0x16FE4, 0x1BC9D, 0x1BC9E, 0x1CF00, 0x1CF01, 0x1CF02, 0x1CF03, 0x1CF04, 0x1CF05, 0x1CF06, - 0x1CF07, 0x1CF08, 0x1CF09, 0x1CF0A, 0x1CF0B, 0x1CF0C, 0x1CF0D, 0x1CF0E, 0x1CF0F, 0x1CF10, 0x1CF11, 0x1CF12, 0x1CF13, - 0x1CF14, 0x1CF15, 0x1CF16, 0x1CF17, 0x1CF18, 0x1CF19, 0x1CF1A, 0x1CF1B, 0x1CF1C, 0x1CF1D, 0x1CF1E, 0x1CF1F, 0x1CF20, - 0x1CF21, 0x1CF22, 0x1CF23, 0x1CF24, 0x1CF25, 0x1CF26, 0x1CF27, 0x1CF28, 0x1CF29, 0x1CF2A, 0x1CF2B, 0x1CF2C, 0x1CF2D, - 0x1CF30, 0x1CF31, 0x1CF32, 0x1CF33, 0x1CF34, 0x1CF35, 0x1CF36, 0x1CF37, 0x1CF38, 0x1CF39, 0x1CF3A, 0x1CF3B, 0x1CF3C, - 0x1CF3D, 0x1CF3E, 0x1CF3F, 0x1CF40, 0x1CF41, 0x1CF42, 0x1CF43, 0x1CF44, 0x1CF45, 0x1CF46, 0x1D167, 0x1D168, 0x1D169, - 0x1D17B, 0x1D17C, 0x1D17D, 0x1D17E, 0x1D17F, 0x1D180, 0x1D181, 0x1D182, 0x1D185, 0x1D186, 0x1D187, 0x1D188, 0x1D189, - 0x1D18A, 0x1D18B, 0x1D1AA, 0x1D1AB, 0x1D1AC, 0x1D1AD, 0x1D242, 0x1D243, 0x1D244, 0x1DA00, 0x1DA01, 0x1DA02, 0x1DA03, - 0x1DA04, 0x1DA05, 0x1DA06, 0x1DA07, 0x1DA08, 0x1DA09, 0x1DA0A, 0x1DA0B, 0x1DA0C, 0x1DA0D, 0x1DA0E, 0x1DA0F, 0x1DA10, - 0x1DA11, 0x1DA12, 0x1DA13, 0x1DA14, 0x1DA15, 0x1DA16, 0x1DA17, 0x1DA18, 0x1DA19, 0x1DA1A, 0x1DA1B, 0x1DA1C, 0x1DA1D, - 0x1DA1E, 0x1DA1F, 0x1DA20, 0x1DA21, 0x1DA22, 0x1DA23, 0x1DA24, 0x1DA25, 0x1DA26, 0x1DA27, 0x1DA28, 0x1DA29, 0x1DA2A, - 0x1DA2B, 0x1DA2C, 0x1DA2D, 0x1DA2E, 0x1DA2F, 0x1DA30, 0x1DA31, 0x1DA32, 0x1DA33, 0x1DA34, 0x1DA35, 0x1DA36, 0x1DA3B, - 0x1DA3C, 0x1DA3D, 0x1DA3E, 0x1DA3F, 0x1DA40, 0x1DA41, 0x1DA42, 0x1DA43, 0x1DA44, 0x1DA45, 0x1DA46, 0x1DA47, 0x1DA48, - 0x1DA49, 0x1DA4A, 0x1DA4B, 0x1DA4C, 0x1DA4D, 0x1DA4E, 0x1DA4F, 0x1DA50, 0x1DA51, 0x1DA52, 0x1DA53, 0x1DA54, 0x1DA55, - 0x1DA56, 0x1DA57, 0x1DA58, 0x1DA59, 0x1DA5A, 0x1DA5B, 0x1DA5C, 0x1DA5D, 0x1DA5E, 0x1DA5F, 0x1DA60, 0x1DA61, 0x1DA62, - 0x1DA63, 0x1DA64, 0x1DA65, 0x1DA66, 0x1DA67, 0x1DA68, 0x1DA69, 0x1DA6A, 0x1DA6B, 0x1DA6C, 0x1DA75, 0x1DA84, 0x1DA9B, - 0x1DA9C, 0x1DA9D, 0x1DA9E, 0x1DA9F, 0x1DAA1, 0x1DAA2, 0x1DAA3, 0x1DAA4, 0x1DAA5, 0x1DAA6, 0x1DAA7, 0x1DAA8, 0x1DAA9, - 0x1DAAA, 0x1DAAB, 0x1DAAC, 0x1DAAD, 0x1DAAE, 0x1DAAF, 0x1E000, 0x1E001, 0x1E002, 0x1E003, 0x1E004, 0x1E005, 0x1E006, - 0x1E008, 0x1E009, 0x1E00A, 0x1E00B, 0x1E00C, 0x1E00D, 0x1E00E, 0x1E00F, 0x1E010, 0x1E011, 0x1E012, 0x1E013, 0x1E014, - 0x1E015, 0x1E016, 0x1E017, 0x1E018, 0x1E01B, 0x1E01C, 0x1E01D, 0x1E01E, 0x1E01F, 0x1E020, 0x1E021, 0x1E023, 0x1E024, - 0x1E026, 0x1E027, 0x1E028, 0x1E029, 0x1E02A, 0x1E08F, 0x1E130, 0x1E131, 0x1E132, 0x1E133, 0x1E134, 0x1E135, 0x1E136, - 0x1E2AE, 0x1E2EC, 0x1E2ED, 0x1E2EE, 0x1E2EF, 0x1E4EC, 0x1E4ED, 0x1E4EE, 0x1E4EF, 0x1E5EE, 0x1E5EF, 0x1E8D0, 0x1E8D1, - 0x1E8D2, 0x1E8D3, 0x1E8D4, 0x1E8D5, 0x1E8D6, 0x1E944, 0x1E945, 0x1E946, 0x1E947, 0x1E948, 0x1E949, 0x1E94A, 0xE0100, - 0xE0101, 0xE0102, 0xE0103, 0xE0104, 0xE0105, 0xE0106, 0xE0107, 0xE0108, 0xE0109, 0xE010A, 0xE010B, 0xE010C, 0xE010D, - 0xE010E, 0xE010F, 0xE0110, 0xE0111, 0xE0112, 0xE0113, 0xE0114, 0xE0115, 0xE0116, 0xE0117, 0xE0118, 0xE0119, 0xE011A, - 0xE011B, 0xE011C, 0xE011D, 0xE011E, 0xE011F, 0xE0120, 0xE0121, 0xE0122, 0xE0123, 0xE0124, 0xE0125, 0xE0126, 0xE0127, - 0xE0128, 0xE0129, 0xE012A, 0xE012B, 0xE012C, 0xE012D, 0xE012E, 0xE012F, 0xE0130, 0xE0131, 0xE0132, 0xE0133, 0xE0134, - 0xE0135, 0xE0136, 0xE0137, 0xE0138, 0xE0139, 0xE013A, 0xE013B, 0xE013C, 0xE013D, 0xE013E, 0xE013F, 0xE0140, 0xE0141, - 0xE0142, 0xE0143, 0xE0144, 0xE0145, 0xE0146, 0xE0147, 0xE0148, 0xE0149, 0xE014A, 0xE014B, 0xE014C, 0xE014D, 0xE014E, - 0xE014F, 0xE0150, 0xE0151, 0xE0152, 0xE0153, 0xE0154, 0xE0155, 0xE0156, 0xE0157, 0xE0158, 0xE0159, 0xE015A, 0xE015B, - 0xE015C, 0xE015D, 0xE015E, 0xE015F, 0xE0160, 0xE0161, 0xE0162, 0xE0163, 0xE0164, 0xE0165, 0xE0166, 0xE0167, 0xE0168, - 0xE0169, 0xE016A, 0xE016B, 0xE016C, 0xE016D, 0xE016E, 0xE016F, 0xE0170, 0xE0171, 0xE0172, 0xE0173, 0xE0174, 0xE0175, - 0xE0176, 0xE0177, 0xE0178, 0xE0179, 0xE017A, 0xE017B, 0xE017C, 0xE017D, 0xE017E, 0xE017F, 0xE0180, 0xE0181, 0xE0182, - 0xE0183, 0xE0184, 0xE0185, 0xE0186, 0xE0187, 0xE0188, 0xE0189, 0xE018A, 0xE018B, 0xE018C, 0xE018D, 0xE018E, 0xE018F, - 0xE0190, 0xE0191, 0xE0192, 0xE0193, 0xE0194, 0xE0195, 0xE0196, 0xE0197, 0xE0198, 0xE0199, 0xE019A, 0xE019B, 0xE019C, - 0xE019D, 0xE019E, 0xE019F, 0xE01A0, 0xE01A1, 0xE01A2, 0xE01A3, 0xE01A4, 0xE01A5, 0xE01A6, 0xE01A7, 0xE01A8, 0xE01A9, - 0xE01AA, 0xE01AB, 0xE01AC, 0xE01AD, 0xE01AE, 0xE01AF, 0xE01B0, 0xE01B1, 0xE01B2, 0xE01B3, 0xE01B4, 0xE01B5, 0xE01B6, - 0xE01B7, 0xE01B8, 0xE01B9, 0xE01BA, 0xE01BB, 0xE01BC, 0xE01BD, 0xE01BE, 0xE01BF, 0xE01C0, 0xE01C1, 0xE01C2, 0xE01C3, - 0xE01C4, 0xE01C5, 0xE01C6, 0xE01C7, 0xE01C8, 0xE01C9, 0xE01CA, 0xE01CB, 0xE01CC, 0xE01CD, 0xE01CE, 0xE01CF, 0xE01D0, - 0xE01D1, 0xE01D2, 0xE01D3, 0xE01D4, 0xE01D5, 0xE01D6, 0xE01D7, 0xE01D8, 0xE01D9, 0xE01DA, 0xE01DB, 0xE01DC, 0xE01DD, - 0xE01DE, 0xE01DF, 0xE01E0, 0xE01E1, 0xE01E2, 0xE01E3, 0xE01E4, 0xE01E5, 0xE01E6, 0xE01E7, 0xE01E8, 0xE01E9, 0xE01EA, - 0xE01EB, 0xE01EC, 0xE01ED, 0xE01EE, 0xE01EF - /* END: COMBINING CHAR TABLE */ -}; - -static const unsigned long combiningCharTableSize = sizeof(combiningCharTable) / sizeof(combiningCharTable[0]); - -static bool isCombiningChar(unsigned long cp) { - for (size_t i = 0; i < combiningCharTableSize; i++) { - auto code = combiningCharTable[i]; - if (code > cp) { - return false; - } - if (code == cp) { - return true; - } - } - return false; -} - -/* Get length of previous grapheme */ -static size_t defaultPrevCharLen(const char * buf, size_t /*buf_len*/, size_t pos, size_t * col_len) { - size_t end = pos; - while (pos > 0) { - size_t len = prevUtf8CodePointLen(buf, pos); - pos -= len; - int cp; - utf8BytesToCodePoint(buf + pos, len, &cp); - if (!isCombiningChar(cp)) { - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - return end - pos; - } - } - /* NOTREACHED */ - return 0; -} - -/* Get length of next grapheme */ -static size_t defaultNextCharLen(const char * buf, size_t buf_len, size_t pos, size_t * col_len) { - size_t beg = pos; - int cp; - size_t len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (isCombiningChar(cp)) { - /* NOTREACHED */ - return 0; - } - if (col_len != NULL) { - *col_len = isWideChar(cp) ? 2 : 1; - } - pos += len; - while (pos < buf_len) { - int cp; - len = utf8BytesToCodePoint(buf + pos, buf_len - pos, &cp); - if (!isCombiningChar(cp)) { - return pos - beg; - } - pos += len; - } - return pos - beg; -} - -/* Read a Unicode from file. */ -static size_t defaultReadCode(int fd, char * buf, size_t buf_len, int * cp) { - if (buf_len < 1) { - return -1; - } - size_t nread = read(fd, &buf[0], 1); - if (nread <= 0) { - return nread; - } - - unsigned char byte = buf[0]; - if ((byte & 0x80) == 0) { - ; - } else if ((byte & 0xE0) == 0xC0) { - if (buf_len < 2) { - return -1; - } - nread = read(fd, &buf[1], 1); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF0) == 0xE0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 2); - if (nread <= 0) { - return nread; - } - } else if ((byte & 0xF8) == 0xF0) { - if (buf_len < 3) { - return -1; - } - nread = read(fd, &buf[1], 3); - if (nread <= 0) { - return nread; - } - } else { - return -1; - } - - return utf8BytesToCodePoint(buf, buf_len, cp); -} - -/* Set default encoding functions */ -static linenoisePrevCharLen * prevCharLen = defaultPrevCharLen; -static linenoiseNextCharLen * nextCharLen = defaultNextCharLen; -static linenoiseReadCode * readCode = defaultReadCode; - -/* Set used defined encoding functions */ -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc) { - prevCharLen = prevCharLenFunc; - nextCharLen = nextCharLenFunc; - readCode = readCodeFunc; -} - -/* ======================= Low level terminal handling ====================== */ - -/* Enable "mask mode". When it is enabled, instead of the input that - * the user is typing, the terminal will just display a corresponding - * number of asterisks, like "****". This is useful for passwords and other - * secrets that should not be displayed. */ -void linenoiseMaskModeEnable(void) { - maskmode = 1; -} - -/* Disable mask mode. */ -void linenoiseMaskModeDisable(void) { - maskmode = 0; -} - -/* Set if to use or not the multi line mode. */ -void linenoiseSetMultiLine(int ml) { - mlmode = ml; -} - -/* Return true if the terminal name is in the list of terminals we know are - * not able to understand basic escape sequences. */ -static int isUnsupportedTerm(void) { - char *term = getenv("TERM"); - if (term == NULL) return 0; - for (size_t j = 0; j < unsupported_term.size(); ++j) { - if (!strcasecmp(term, unsupported_term[j])) { - return 1; - } - } - return 0; -} - -/* Raw mode: 1960 magic shit. */ -static int enableRawMode(int fd) { - struct termios raw; - - if (!isatty(STDIN_FILENO)) goto fatal; - if (!atexit_registered) { - atexit(linenoiseAtExit); - atexit_registered = 1; - } - if (tcgetattr(fd,&orig_termios) == -1) goto fatal; - - raw = orig_termios; /* modify the original mode */ - /* input modes: no break, no CR to NL, no parity check, no strip char, - * no start/stop output control. */ - raw.c_iflag &= ~(BRKINT | ICRNL | INPCK | ISTRIP | IXON); - /* output modes - disable post processing */ - raw.c_oflag &= ~(OPOST); - /* control modes - set 8 bit chars */ - raw.c_cflag |= (CS8); - /* local modes - choing off, canonical off, no extended functions, - * no signal chars (^Z,^C) */ - raw.c_lflag &= ~(ECHO | ICANON | IEXTEN | ISIG); - /* control chars - set return condition: min number of bytes and timer. - * We want read to return every single byte, without timeout. */ - raw.c_cc[VMIN] = 1; raw.c_cc[VTIME] = 0; /* 1 byte, no timer */ - - /* put terminal in raw mode after flushing */ - if (tcsetattr(fd,TCSAFLUSH,&raw) < 0) goto fatal; - rawmode = 1; - return 0; - -fatal: - errno = ENOTTY; - return -1; -} - -static void disableRawMode(int fd) { - /* Don't even check the return value as it's too late. */ - if (rawmode && tcsetattr(fd,TCSAFLUSH,&orig_termios) != -1) - rawmode = 0; -} - -/* Use the ESC [6n escape sequence to query the horizontal cursor position - * and return it. On error -1 is returned, on success the position of the - * cursor. */ -static int getCursorPosition(int ifd, int ofd) { - char buf[32]; - int cols, rows; - unsigned int i = 0; - - /* Report cursor location */ - if (write(ofd, "\x1b[6n", 4) != 4) return -1; - - /* Read the response: ESC [ rows ; cols R */ - while (i < sizeof(buf)-1) { - if (read(ifd,buf+i,1) != 1) break; - if (buf[i] == 'R') break; - i++; - } - buf[i] = '\0'; - - /* Parse it. */ - if (buf[0] != ESC || buf[1] != '[') return -1; - if (sscanf(buf+2,"%d;%d",&rows,&cols) != 2) return -1; - return cols; -} - -/* Try to get the number of columns in the current terminal, or assume 80 - * if it fails. */ -static int getColumns(int ifd, int ofd) { - struct winsize ws; - - if (ioctl(1, TIOCGWINSZ, &ws) == -1 || ws.ws_col == 0) { - /* ioctl() failed. Try to query the terminal itself. */ - int start, cols; - - /* Get the initial position so we can restore it later. */ - start = getCursorPosition(ifd,ofd); - if (start == -1) goto failed; - - /* Go to right margin and get position. */ - if (write(ofd,"\x1b[999C",6) != 6) goto failed; - cols = getCursorPosition(ifd,ofd); - if (cols == -1) goto failed; - - /* Restore position. */ - if (cols > start) { - char seq[32]; - snprintf(seq,32,"\x1b[%dD",cols-start); - if (write(ofd,seq,strlen(seq)) == -1) { - /* Can't recover... */ - } - } - return cols; - } else { - return ws.ws_col; - } - -failed: - return 80; -} - -/* Clear the screen. Used to handle ctrl+l */ -void linenoiseClearScreen(void) { - if (write(STDOUT_FILENO,"\x1b[H\x1b[2J",7) <= 0) { - /* nothing to do, just to avoid warning. */ - } -} - -/* Beep, used for completion when there is nothing to complete or when all - * the choices were already shown. */ -static void linenoiseBeep(void) { - fprintf(stderr, "\x7"); - fflush(stderr); -} - -/* Called by completeLine() and linenoiseShow() to render the current - * edited line with the proposed completion. If the current completion table - * is already available, it is passed as second argument, otherwise the - * function will use the callback to obtain it. - * - * Flags are the same as refreshLine*(), that is REFRESH_* macros. */ -static void refreshLineWithCompletion(struct linenoiseState *ls, linenoiseCompletions *lc, int flags) { - /* Obtain the table of completions if the caller didn't provide one. */ - linenoiseCompletions ctable; - if (lc == NULL) { - completionCallback(ls->buf, &ctable); - lc = &ctable; - } - - /* Show the edited line with completion if possible, or just refresh. */ - if (ls->completion_idx < lc->len) { - struct linenoiseState saved = *ls; - ls->len = ls->pos = strlen(lc->cvec[ls->completion_idx]); - ls->buf = lc->cvec[ls->completion_idx]; - refreshLineWithFlags(ls, flags); - ls->len = saved.len; - ls->pos = saved.pos; - ls->buf = saved.buf; - } else { - refreshLineWithFlags(ls, flags); - } - - if (lc == &ctable) { - ctable.to_free = false; - } -} - -enum ESC_TYPE { ESC_NULL = 0, ESC_DELETE, ESC_UP, ESC_DOWN, ESC_RIGHT, ESC_LEFT, ESC_HOME, ESC_END }; - -static ESC_TYPE readEscapeSequence(struct linenoiseState * l) { - /* Check if the file input has additional data. */ - struct pollfd pfd; - pfd.fd = l->ifd; - pfd.events = POLLIN; - - auto ret = poll(&pfd, 1, 1); // 1 millisecond timeout - if (ret <= 0) { // -1: error, 0: timeout - return ESC_NULL; - } - - /* Read the next two bytes representing the escape sequence. - * Use two calls to handle slow terminals returning the two - * chars at different times. */ - char seq[3]; - if (read(l->ifd, seq, 1) == -1) { - return ESC_NULL; - } - if (read(l->ifd, seq + 1, 1) == -1) { - return ESC_NULL; - } - - /* ESC [ sequences. */ - if (seq[0] == '[') { - if (seq[1] >= '0' && seq[1] <= '9') { - /* Extended escape, read additional byte. */ - if (read(l->ifd, seq + 2, 1) == -1) { - return ESC_NULL; - } - if (seq[2] == '~') { - switch (seq[1]) { - case '3': - return ESC_DELETE; - } - } - } else { - switch (seq[1]) { - case 'A': - return ESC_UP; - case 'B': - return ESC_DOWN; - case 'C': - return ESC_RIGHT; - case 'D': - return ESC_LEFT; - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - } - - /* ESC O sequences. */ - else if (seq[0] == 'O') { - switch (seq[1]) { - case 'H': - return ESC_HOME; - case 'F': - return ESC_END; - } - } - return ESC_NULL; -} - -/* This is an helper function for linenoiseEdit*() and is called when the - * user types the key in order to complete the string currently in the - * input. - * - * The state of the editing is encapsulated into the pointed linenoiseState - * structure as described in the structure definition. - * - * If the function returns non-zero, the caller should handle the - * returned value as a byte read from the standard input, and process - * it as usually: this basically means that the function may return a byte - * read from the terminal but not processed. Otherwise, if zero is returned, - * the input was consumed by the completeLine() function to navigate the - * possible completions, and the caller should read for the next characters - * from stdin. */ -static int completeLine(struct linenoiseState * ls, int keypressed, ESC_TYPE esc_type) { - linenoiseCompletions lc; - int nwritten; - char c = keypressed; - - completionCallback(ls->buf, &lc); - if (lc.len == 0) { - linenoiseBeep(); - ls->in_completion = 0; - } else { - if (c == TAB) { - if (ls->in_completion == 0) { - ls->in_completion = 1; - ls->completion_idx = 0; - } else { - ls->completion_idx = (ls->completion_idx + 1) % (lc.len + 1); - if (ls->completion_idx == lc.len) { - linenoiseBeep(); - } - } - c = 0; - } else if (c == ESC && esc_type == ESC_NULL) { - /* Re-show original buffer */ - if (ls->completion_idx < lc.len) { - refreshLine(ls); - } - ls->in_completion = 0; - c = 0; - } else { - /* Update buffer and return */ - if (ls->completion_idx < lc.len) { - nwritten = snprintf(ls->buf, ls->buflen, "%s", lc.cvec[ls->completion_idx]); - ls->len = ls->pos = nwritten; - } - ls->in_completion = 0; - } - - /* Show completion or original buffer */ - if (ls->in_completion && ls->completion_idx < lc.len) { - refreshLineWithCompletion(ls, &lc, REFRESH_ALL); - } else { - refreshLine(ls); - } - } - - return c; /* Return last read character */ -} - -/* Register a callback function to be called for tab-completion. */ -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *fn) { - completionCallback = fn; -} - -/* Register a hits function to be called to show hits to the user at the - * right of the prompt. */ -void linenoiseSetHintsCallback(linenoiseHintsCallback *fn) { - hintsCallback = fn; -} - -/* Register a function to free the hints returned by the hints callback - * registered with linenoiseSetHintsCallback(). */ -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *fn) { - freeHintsCallback = fn; -} - -/* This function is used by the callback function registered by the user - * in order to add completion options given the input string when the - * user typed . See the example.c source code for a very easy to - * understand example. */ -void linenoiseAddCompletion(linenoiseCompletions *lc, const char *str) { - const size_t len = strlen(str); - auto copy = std::make_unique(len + 1); - if (!copy) { - return; - } - - memcpy(copy.get(), str, len + 1); - char ** cvec = static_cast(std::realloc(lc->cvec, sizeof(char *) * (lc->len + 1))); - if (cvec == nullptr) { - return; - } - - lc->cvec = cvec; - lc->cvec[lc->len++] = copy.release(); -} - -/* Get column length from begining of buffer to current byte position */ -static size_t columnPos(const char * buf, size_t buf_len, size_t pos) { - size_t ret = 0; - size_t off = 0; - while (off < pos) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - off += len; - ret += col_len; - } - return ret; -} - -/* Helper of refreshSingleLine() and refreshMultiLine() to show hints - * to the right of the prompt. */ -static void refreshShowHints(std::string & ab, struct linenoiseState * l, int pcollen) { - char seq[64]; - size_t collen = pcollen + columnPos(l->buf, l->len, l->len); - if (hintsCallback && collen < l->cols) { - int color = -1, bold = 0; - const char *hint = hintsCallback(l->buf,&color,&bold); - if (hint) { - int hintlen = strlen(hint); - int hintmaxlen = l->cols - collen; - if (hintlen > hintmaxlen) hintlen = hintmaxlen; - if (bold == 1 && color == -1) color = 37; - if (color != -1 || bold != 0) - snprintf(seq,64,"\033[%d;%d;49m",bold,color); - else - seq[0] = '\0'; - ab.append(seq); - ab.append(hint, hintlen); - if (color != -1 || bold != 0) - ab.append("\033[0m"); - - /* Call the function to free the hint returned. */ - if (freeHintsCallback) freeHintsCallback(hint); - } - } -} - -/* Check if text is an ANSI escape sequence */ -static int isAnsiEscape(const char * buf, size_t buf_len, size_t * len) { - if (buf_len > 2 && !memcmp("\033[", buf, 2)) { - size_t off = 2; - while (off < buf_len) { - switch (buf[off++]) { - case 'A': - case 'B': - case 'C': - case 'D': - case 'E': - case 'F': - case 'G': - case 'H': - case 'J': - case 'K': - case 'S': - case 'T': - case 'f': - case 'm': - *len = off; - return 1; - } - } - } - return 0; -} - -/* Get column length of prompt text */ -static size_t promptTextColumnLen(const char * prompt, size_t plen) { - char buf[LINENOISE_MAX_LINE]; - size_t buf_len = 0; - size_t off = 0; - while (off < plen) { - size_t len; - if (isAnsiEscape(prompt + off, plen - off, &len)) { - off += len; - continue; - } - buf[buf_len++] = prompt[off++]; - } - return columnPos(buf, buf_len, buf_len); -} - -/* Single line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshSingleLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int fd = l->ofd; - char *buf = l->buf; - size_t len = l->len; - size_t pos = l->pos; - std::string ab; - - while ((pcollen + columnPos(buf, len, pos)) >= l->cols) { - int chlen = nextCharLen(buf, len, 0, NULL); - buf += chlen; - len -= chlen; - pos -= chlen; - } - while (pcollen + columnPos(buf, len, len) > l->cols) { - len -= prevCharLen(buf, len, len, NULL); - } - - /* Cursor to left edge */ - snprintf(seq,sizeof(seq),"\r"); - ab.append(seq); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - while (len--) { - ab.append("*"); - } - } else { - ab.append(buf, len); - } - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - } - - /* Erase to right */ - snprintf(seq,sizeof(seq),"\x1b[0K"); - ab.append(seq); - if (flags & REFRESH_WRITE) { - /* Move cursor to original position. */ - snprintf(seq, sizeof(seq), "\r\x1b[%dC", (int) (columnPos(buf, len, pos) + pcollen)); - ab.append(seq); - } - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Get column length from begining of buffer to current byte position for multiline mode*/ -static size_t columnPosForMultiLine(const char * buf, size_t buf_len, size_t pos, size_t cols, size_t ini_pos) { - size_t ret = 0; - size_t colwid = ini_pos; - - size_t off = 0; - while (off < buf_len) { - size_t col_len; - size_t len = nextCharLen(buf, buf_len, off, &col_len); - - int dif = (int) (colwid + col_len) - (int) cols; - if (dif > 0) { - ret += dif; - colwid = col_len; - } else if (dif == 0) { - colwid = 0; - } else { - colwid += col_len; - } - - if (off >= pos) { - break; - } - off += len; - ret += col_len; - } - - return ret; -} - -/* Multi line low level line refresh. - * - * Rewrite the currently edited line accordingly to the buffer content, - * cursor position, and number of columns of the terminal. - * - * Flags is REFRESH_* macros. The function can just remove the old - * prompt, just write it, or both. */ -static void refreshMultiLine(struct linenoiseState *l, int flags) { - char seq[64]; - size_t pcollen = promptTextColumnLen(l->prompt, strlen(l->prompt)); - int colpos = columnPosForMultiLine(l->buf, l->len, l->len, l->cols, pcollen); - int colpos2; /* cursor column position. */ - int rows = (pcollen + colpos + l->cols - 1) / l->cols; /* rows used by current buf. */ - int rpos = (pcollen + l->oldcolpos + l->cols) / l->cols; /* cursor relative row. */ - int rpos2; /* rpos after refresh. */ - int col; /* column position, zero-based. */ - int old_rows = l->oldrows; - int fd = l->ofd, j; - std::string ab; - l->oldrows = rows; - - /* First step: clear all the lines used before. To do so start by - * going to the last row. */ - if (flags & REFRESH_CLEAN) { - if (old_rows - rpos > 0) { - snprintf(seq,64,"\x1b[%dB", old_rows-rpos); - ab.append(seq); - } - - /* Now for every row clear it, go up. */ - for (j = 0; j < old_rows - 1; j++) { - snprintf(seq,64,"\r\x1b[0K\x1b[1A"); - ab.append(seq); - } - } - - if (flags & REFRESH_ALL) { - /* Clean the top line. */ - snprintf(seq,64,"\r\x1b[0K"); - ab.append(seq); - } - - /* Get column length to cursor position */ - colpos2 = columnPosForMultiLine(l->buf, l->len, l->pos, l->cols, pcollen); - - if (flags & REFRESH_WRITE) { - /* Write the prompt and the current buffer content */ - ab.append(l->prompt); - if (maskmode == 1) { - for (unsigned int i = 0; i < l->len; ++i) { - ab.append("*"); - } - } else { - ab.append(l->buf, l->len); - } - - /* Show hits if any. */ - refreshShowHints(ab, l, pcollen); - - /* If we are at the very end of the screen with our prompt, we need to - * emit a newline and move the prompt to the first column. */ - if (l->pos && l->pos == l->len && (colpos2 + pcollen) % l->cols == 0) { - ab.append("\n"); - snprintf(seq,64,"\r"); - ab.append(seq); - rows++; - if (rows > (int)l->oldrows) l->oldrows = rows; - } - - /* Move cursor to right position. */ - rpos2 = (pcollen + colpos2 + l->cols) / l->cols; /* Current cursor relative row */ - - /* Go up till we reach the expected position. */ - if (rows - rpos2 > 0) { - snprintf(seq,64,"\x1b[%dA", rows-rpos2); - ab.append(seq); - } - - /* Set column. */ - col = (pcollen + colpos2) % l->cols; - if (col) - snprintf(seq,64,"\r\x1b[%dC", col); - else - snprintf(seq,64,"\r"); - ab.append(seq); - } - - l->oldcolpos = colpos2; - - (void) !write(fd, ab.c_str(), ab.size()); /* Can't recover from write error. */ -} - -/* Calls the two low level functions refreshSingleLine() or - * refreshMultiLine() according to the selected mode. */ -static void refreshLineWithFlags(struct linenoiseState *l, int flags) { - if (mlmode) - refreshMultiLine(l,flags); - else - refreshSingleLine(l,flags); -} - -/* Utility function to avoid specifying REFRESH_ALL all the times. */ -static void refreshLine(struct linenoiseState *l) { - refreshLineWithFlags(l,REFRESH_ALL); -} - -/* Hide the current line, when using the multiplexing API. */ -void linenoiseHide(struct linenoiseState *l) { - if (mlmode) - refreshMultiLine(l,REFRESH_CLEAN); - else - refreshSingleLine(l,REFRESH_CLEAN); -} - -/* Show the current line, when using the multiplexing API. */ -void linenoiseShow(struct linenoiseState *l) { - if (l->in_completion) { - refreshLineWithCompletion(l,NULL,REFRESH_WRITE); - } else { - refreshLineWithFlags(l,REFRESH_WRITE); - } -} - -/* Insert the character 'c' at cursor current position. - * - * On error writing to the terminal -1 is returned, otherwise 0. */ -static int linenoiseEditInsert(struct linenoiseState * l, const char * cbuf, int clen) { - if (l->len + clen <= l->buflen) { - if (l->len == l->pos) { - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - ; - l->buf[l->len] = '\0'; - if ((!mlmode && promptTextColumnLen(l->prompt, l->plen) + columnPos(l->buf, l->len, l->len) < l->cols && - !hintsCallback)) { - /* Avoid a full update of the line in the - * trivial case. */ - if (maskmode == 1) { - static const char d = '*'; - if (write(l->ofd, &d, 1) == -1) { - return -1; - } - } else { - if (write(l->ofd, cbuf, clen) == -1) { - return -1; - } - } - } else { - refreshLine(l); - } - } else { - memmove(l->buf + l->pos + clen, l->buf + l->pos, l->len - l->pos); - memcpy(&l->buf[l->pos], cbuf, clen); - l->pos += clen; - l->len += clen; - l->buf[l->len] = '\0'; - refreshLine(l); - } - } - return 0; -} - -/* Move cursor on the left. */ -static void linenoiseEditMoveLeft(struct linenoiseState * l) { - if (l->pos > 0) { - l->pos -= prevCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor on the right. */ -static void linenoiseEditMoveRight(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos += nextCharLen(l->buf, l->len, l->pos, NULL); - refreshLine(l); - } -} - -/* Move cursor to the start of the line. */ -static void linenoiseEditMoveHome(struct linenoiseState * l) { - if (l->pos != 0) { - l->pos = 0; - refreshLine(l); - } -} - -/* Move cursor to the end of the line. */ -static void linenoiseEditMoveEnd(struct linenoiseState * l) { - if (l->pos != l->len) { - l->pos = l->len; - refreshLine(l); - } -} - -/* Substitute the currently edited line with the next or previous history - * entry as specified by 'dir'. */ -#define LINENOISE_HISTORY_NEXT 0 -#define LINENOISE_HISTORY_PREV 1 - -static void linenoiseEditHistoryNext(struct linenoiseState * l, int dir) { - if (history_len > 1) { - /* Update the current history entry before to - * overwrite it with the next one. */ - free(history[history_len - 1 - l->history_index]); - history[history_len - 1 - l->history_index] = strdup(l->buf); - /* Show the new entry */ - l->history_index += (dir == LINENOISE_HISTORY_PREV) ? 1 : -1; - if (l->history_index < 0) { - l->history_index = 0; - return; - } else if (l->history_index >= history_len) { - l->history_index = history_len-1; - return; - } - strncpy(l->buf,history[history_len - 1 - l->history_index],l->buflen); - l->buf[l->buflen-1] = '\0'; - l->len = l->pos = strlen(l->buf); - refreshLine(l); - } -} - -/* Delete the character at the right of the cursor without altering the cursor - * position. Basically this is what happens with the "Delete" keyboard key. */ -static void linenoiseEditDelete(struct linenoiseState * l) { - if (l->len > 0 && l->pos < l->len) { - int chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos, l->buf + l->pos + chlen, l->len - l->pos - chlen); - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Backspace implementation. */ -static void linenoiseEditBackspace(struct linenoiseState * l) { - if (l->pos > 0 && l->len > 0) { - int chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - memmove(l->buf + l->pos - chlen, l->buf + l->pos, l->len - l->pos); - l->pos -= chlen; - l->len -= chlen; - l->buf[l->len] = '\0'; - refreshLine(l); - } -} - -/* Delete the previous word, maintaining the cursor at the start of the - * current word. */ -static void linenoiseEditDeletePrevWord(struct linenoiseState * l) { - size_t old_pos = l->pos; - size_t diff; - - while (l->pos > 0 && l->buf[l->pos-1] == ' ') - l->pos--; - while (l->pos > 0 && l->buf[l->pos-1] != ' ') - l->pos--; - diff = old_pos - l->pos; - memmove(l->buf+l->pos,l->buf+old_pos,l->len-old_pos+1); - l->len -= diff; - refreshLine(l); -} - -/* This function is part of the multiplexed API of Linenoise, that is used - * in order to implement the blocking variant of the API but can also be - * called by the user directly in an event driven program. It will: - * - * 1. Initialize the linenoise state passed by the user. - * 2. Put the terminal in RAW mode. - * 3. Show the prompt. - * 4. Return control to the user, that will have to call linenoiseEditFeed() - * each time there is some data arriving in the standard input. - * - * The user can also call linenoiseEditHide() and linenoiseEditShow() if it - * is required to show some input arriving asynchronously, without mixing - * it with the currently edited line. - * - * When linenoiseEditFeed() returns non-NULL, the user finished with the - * line editing session (pressed enter CTRL-D/C): in this case the caller - * needs to call linenoiseEditStop() to put back the terminal in normal - * mode. This will not destroy the buffer, as long as the linenoiseState - * is still valid in the context of the caller. - * - * The function returns 0 on success, or -1 if writing to standard output - * fails. If stdin_fd or stdout_fd are set to -1, the default is to use - * STDIN_FILENO and STDOUT_FILENO. - */ -int linenoiseEditStart(struct linenoiseState *l, int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) { - /* Populate the linenoise state that we pass to functions implementing - * specific editing functionalities. */ - l->in_completion = 0; - l->ifd = stdin_fd != -1 ? stdin_fd : STDIN_FILENO; - l->ofd = stdout_fd != -1 ? stdout_fd : STDOUT_FILENO; - l->buf = buf; - l->buflen = buflen; - l->prompt = prompt; - l->plen = strlen(prompt); - l->oldcolpos = l->pos = 0; - l->len = 0; - - /* Enter raw mode. */ - if (enableRawMode(l->ifd) == -1) return -1; - - l->cols = getColumns(stdin_fd, stdout_fd); - l->oldrows = 0; - l->history_index = 0; - - /* Buffer starts empty. */ - l->buf[0] = '\0'; - l->buflen--; /* Make sure there is always space for the nullterm */ - - /* If stdin is not a tty, stop here with the initialization. We - * will actually just read a line from standard input in blocking - * mode later, in linenoiseEditFeed(). */ - if (!isatty(l->ifd)) return 0; - - /* The latest history entry is always our current buffer, that - * initially is just an empty string. */ - linenoiseHistoryAdd(""); - - if (write(l->ofd,prompt,l->plen) == -1) return -1; - return 0; -} - -const char* linenoiseEditMore = "If you see this, you are misusing the API: when linenoiseEditFeed() is called, if it returns linenoiseEditMore the user is yet editing the line. See the README file for more information."; - -static const char * handleEnterKey(struct linenoiseState * l) { - --history_len; - free(history[history_len]); - if (mlmode) { - linenoiseEditMoveEnd(l); - } - if (hintsCallback) { - /* Force a refresh without hints to leave the previous - * line as the user typed it after a newline. */ - linenoiseHintsCallback * hc = hintsCallback; - hintsCallback = NULL; - refreshLine(l); - hintsCallback = hc; - } - - return strdup(l->buf); -} - -static const char * handleCtrlCKey() { - errno = EAGAIN; - return NULL; -} - -static const char * handleCtrlDKey(struct linenoiseState * l) { - if (l->len > 0) { - linenoiseEditDelete(l); - return linenoiseEditMore; - } - - --history_len; - free(history[history_len]); - errno = ENOENT; - return NULL; -} - -static void handleCtrlTKey(struct linenoiseState * l) { - if (l->pos > 0 && l->pos < l->len) { - auto prev_chlen = prevCharLen(l->buf, l->len, l->pos, NULL); - auto curr_chlen = nextCharLen(l->buf, l->len, l->pos, NULL); - - std::string prev_char(prev_chlen, 0); - memcpy(prev_char.data(), l->buf + l->pos - prev_chlen, prev_chlen); - memmove(l->buf + l->pos - prev_chlen, l->buf + l->pos, curr_chlen); - memmove(l->buf + l->pos - prev_chlen + curr_chlen, prev_char.data(), prev_chlen); - - l->pos = l->pos - prev_chlen + curr_chlen; - if (l->pos + prev_chlen != l->len) { - l->pos += prev_chlen; - } - - refreshLine(l); - } -} - -static void handleEscapeSequence(struct linenoiseState * l, int esc_type) { - switch (esc_type) { - case ESC_NULL: - break; - case ESC_DELETE: - linenoiseEditDelete(l); - break; - case ESC_UP: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case ESC_DOWN: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC_RIGHT: - linenoiseEditMoveRight(l); - break; - case ESC_LEFT: - linenoiseEditMoveLeft(l); - break; - case ESC_HOME: - linenoiseEditMoveHome(l); - break; - case ESC_END: - linenoiseEditMoveEnd(l); - break; - } -} - -static void handleCtrlUKey(struct linenoiseState * l) { - l->buf[0] = '\0'; - l->pos = l->len = 0; - refreshLine(l); -} - -static void handleCtrlKKey(struct linenoiseState * l) { - l->buf[l->pos] = '\0'; - l->len = l->pos; - refreshLine(l); -} - -static const char * processInputCharacter(struct linenoiseState * l, int c, char * cbuf, int nread, int esc_type) { - switch (c) { - case ENTER: - return handleEnterKey(l); - case CTRL_C: - return handleCtrlCKey(); - case BACKSPACE: - case CTRL_H: - linenoiseEditBackspace(l); - break; - case CTRL_D: /* ctrl-d, remove char at right of cursor, or if the - line is empty, act as end-of-file. */ - return handleCtrlDKey(l); - case CTRL_T: - handleCtrlTKey(l); - break; - case CTRL_B: - linenoiseEditMoveLeft(l); - break; - case CTRL_F: - linenoiseEditMoveRight(l); - break; - case CTRL_P: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_PREV); - break; - case CTRL_N: - linenoiseEditHistoryNext(l, LINENOISE_HISTORY_NEXT); - break; - case ESC: - handleEscapeSequence(l, esc_type); - break; - default: - if (linenoiseEditInsert(l, cbuf, nread)) { - return NULL; - } - break; - case CTRL_U: /* Ctrl+u, delete the whole line. */ - handleCtrlUKey(l); - break; - case CTRL_K: /* Ctrl+k, delete from current to end of line. */ - handleCtrlKKey(l); - break; - case CTRL_A: /* Ctrl+a, go to the start of the line */ - linenoiseEditMoveHome(l); - break; - case CTRL_E: /* ctrl+e, go to the end of the line */ - linenoiseEditMoveEnd(l); - break; - case CTRL_L: /* ctrl+l, clear screen */ - linenoiseClearScreen(); - refreshLine(l); - break; - case CTRL_W: /* ctrl+w, delete previous word */ - linenoiseEditDeletePrevWord(l); - break; - } - return linenoiseEditMore; -} - -/* This function is part of the multiplexed API of linenoise, see the top - * comment on linenoiseEditStart() for more information. Call this function - * each time there is some data to read from the standard input file - * descriptor. In the case of blocking operations, this function can just be - * called in a loop, and block. - * - * The function returns linenoiseEditMore to signal that line editing is still - * in progress, that is, the user didn't yet pressed enter / CTRL-D. Otherwise - * the function returns the pointer to the heap-allocated buffer with the - * edited line, that the user should free with linenoiseFree(). - * - * On special conditions, NULL is returned and errno is populated: - * - * EAGAIN if the user pressed Ctrl-C - * ENOENT if the user pressed Ctrl-D - * - * Some other errno: I/O error. - */ -const char * linenoiseEditFeed(struct linenoiseState * l) { - /* Not a TTY, pass control to line reading without character count - * limits. */ - if (!isatty(l->ifd)) return linenoiseNoTTY(); - - int c; - int nread; - char cbuf[32]; - - nread = readCode(l->ifd, cbuf, sizeof(cbuf), &c); - if (nread <= 0) return NULL; - - auto esc_type = ESC_NULL; - if (c == ESC) { - esc_type = readEscapeSequence(l); - } - - /* Only autocomplete when the callback is set. It returns < 0 when - * there was an error reading from fd. Otherwise it will return the - * character that should be handled next. */ - if ((l->in_completion || c == 9) && completionCallback != NULL) { - c = completeLine(l, c, esc_type); - /* Read next character when 0 */ - if (c == 0) return linenoiseEditMore; - } - - return processInputCharacter(l, c, cbuf, nread, esc_type); -} - -/* This is part of the multiplexed linenoise API. See linenoiseEditStart() - * for more information. This function is called when linenoiseEditFeed() - * returns something different than NULL. At this point the user input - * is in the buffer, and we can restore the terminal in normal mode. */ -void linenoiseEditStop(struct linenoiseState *l) { - if (!isatty(l->ifd)) return; - disableRawMode(l->ifd); - printf("\n"); -} - -/* This just implements a blocking loop for the multiplexed API. - * In many applications that are not event-driven, we can just call - * the blocking linenoise API, wait for the user to complete the editing - * and return the buffer. */ -static const char *linenoiseBlockingEdit(int stdin_fd, int stdout_fd, char *buf, size_t buflen, const char *prompt) -{ - struct linenoiseState l; - - /* Editing without a buffer is invalid. */ - if (buflen == 0) { - errno = EINVAL; - return NULL; - } - - linenoiseEditStart(&l,stdin_fd,stdout_fd,buf,buflen,prompt); - const char *res; - while((res = linenoiseEditFeed(&l)) == linenoiseEditMore); - linenoiseEditStop(&l); - return res; -} - -/* This special mode is used by linenoise in order to print scan codes - * on screen for debugging / development purposes. It is implemented - * by the linenoise_example program using the --keycodes option. */ -void linenoisePrintKeyCodes(void) { - char quit[4]; - - printf("Linenoise key codes debugging mode.\n" - "Press keys to see scan codes. Type 'quit' at any time to exit.\n"); - if (enableRawMode(STDIN_FILENO) == -1) return; - memset(quit,' ',4); - while(1) { - char c; - int nread; - - nread = read(STDIN_FILENO,&c,1); - if (nread <= 0) continue; - memmove(quit,quit+1,sizeof(quit)-1); /* shift string to left. */ - quit[sizeof(quit)-1] = c; /* Insert current char on the right. */ - if (memcmp(quit,"quit",sizeof(quit)) == 0) break; - - printf("'%c' %02x (%d) (type quit to exit)\n", isprint((int) c) ? c : '?', (int) c, (int) c); - printf("\r"); /* Go left edge manually, we are in raw mode. */ - fflush(stdout); - } - disableRawMode(STDIN_FILENO); -} - -/* This function is called when linenoise() is called with the standard - * input file descriptor not attached to a TTY. So for example when the - * program using linenoise is called in pipe or with a file redirected - * to its standard input. In this case, we want to be able to return the - * line regardless of its length (by default we are limited to 4k). */ -static char *linenoiseNoTTY(void) { - char *line = NULL; - size_t len = 0, maxlen = 0; - - while(1) { - if (len == maxlen) { - if (maxlen == 0) maxlen = 16; - maxlen *= 2; - char *oldval = line; - line = (char*) realloc(line,maxlen); - if (line == NULL) { - if (oldval) free(oldval); - return NULL; - } - } - int c = fgetc(stdin); - if (c == EOF || c == '\n') { - if (c == EOF && len == 0) { - free(line); - return NULL; - } else { - line[len] = '\0'; - return line; - } - } else { - line[len] = c; - len++; - } - } -} - -/* The high level function that is the main API of the linenoise library. - * This function checks if the terminal has basic capabilities, just checking - * for a blacklist of stupid terminals, and later either calls the line - * editing function or uses dummy fgets() so that you will be able to type - * something even in the most desperate of the conditions. */ -const char *linenoise(const char *prompt) { - char buf[LINENOISE_MAX_LINE]; - - if (!isatty(STDIN_FILENO)) { - /* Not a tty: read from file / pipe. In this mode we don't want any - * limit to the line size, so we call a function to handle that. */ - return linenoiseNoTTY(); - } else if (isUnsupportedTerm()) { - size_t len; - - printf("%s",prompt); - fflush(stdout); - if (fgets(buf,LINENOISE_MAX_LINE,stdin) == NULL) return NULL; - len = strlen(buf); - while(len && (buf[len-1] == '\n' || buf[len-1] == '\r')) { - len--; - buf[len] = '\0'; - } - return strdup(buf); - } else { - const char *retval = linenoiseBlockingEdit(STDIN_FILENO,STDOUT_FILENO,buf,LINENOISE_MAX_LINE,prompt); - return retval; - } -} - -/* This is just a wrapper the user may want to call in order to make sure - * the linenoise returned buffer is freed with the same allocator it was - * created with. Useful when the main program is using an alternative - * allocator. */ -void linenoiseFree(void *ptr) { - if (ptr == linenoiseEditMore) return; // Protect from API misuse. - free(ptr); -} - -/* ================================ History ================================= */ - -/* Free the history, but does not reset it. Only used when we have to - * exit() to avoid memory leaks are reported by valgrind & co. */ -static void freeHistory(void) { - if (history) { - int j; - - for (j = 0; j < history_len; j++) - free(history[j]); - free(history); - } -} - -/* At exit we'll try to fix the terminal to the initial conditions. */ -static void linenoiseAtExit(void) { - disableRawMode(STDIN_FILENO); - freeHistory(); -} - -/* This is the API call to add a new entry in the linenoise history. - * It uses a fixed array of char pointers that are shifted (memmoved) - * when the history max length is reached in order to remove the older - * entry and make room for the new one, so it is not exactly suitable for huge - * histories, but will work well for a few hundred of entries. - * - * Using a circular buffer is smarter, but a bit more complex to handle. */ -int linenoiseHistoryAdd(const char *line) { - char *linecopy; - - if (history_max_len == 0) return 0; - - /* Initialization on first call. */ - if (history == NULL) { - history = (char**) malloc(sizeof(char*)*history_max_len); - if (history == NULL) return 0; - memset(history,0,(sizeof(char*)*history_max_len)); - } - - /* Don't add duplicated lines. */ - if (history_len && !strcmp(history[history_len-1], line)) return 0; - - /* Add an heap allocated copy of the line in the history. - * If we reached the max length, remove the older line. */ - linecopy = strdup(line); - if (!linecopy) return 0; - if (history_len == history_max_len) { - free(history[0]); - memmove(history,history+1,sizeof(char*)*(history_max_len-1)); - history_len--; - } - history[history_len] = linecopy; - history_len++; - return 1; -} - -/* Set the maximum length for the history. This function can be called even - * if there is already some history, the function will make sure to retain - * just the latest 'len' elements if the new history length value is smaller - * than the amount of items already inside the history. */ -int linenoiseHistorySetMaxLen(int len) { - char **new_ptr; - - if (len < 1) return 0; - if (history) { - int tocopy = history_len; - - new_ptr = (char**) malloc(sizeof(char*)*len); - if (new_ptr == NULL) return 0; - - /* If we can't copy everything, free the elements we'll not use. */ - if (len < tocopy) { - int j; - - for (j = 0; j < tocopy-len; j++) free(history[j]); - tocopy = len; - } - memset(new_ptr,0,sizeof(char*)*len); - memcpy(new_ptr,history+(history_len-tocopy), sizeof(char*)*tocopy); - free(history); - history = new_ptr; - } - history_max_len = len; - if (history_len > history_max_len) - history_len = history_max_len; - return 1; -} - -/* Save the history in the specified file. On success 0 is returned - * otherwise -1 is returned. */ -int linenoiseHistorySave(const char *filename) { - mode_t old_umask = umask(S_IXUSR|S_IRWXG|S_IRWXO); - File file; - file.open(filename, "w"); - umask(old_umask); - if (file.file == NULL) { - return -1; - } - chmod(filename,S_IRUSR|S_IWUSR); - for (int j = 0; j < history_len; ++j) { - fprintf(file.file, "%s\n", history[j]); - } - - return 0; -} - -/* Load the history from the specified file. If the file does not exist - * zero is returned and no operation is performed. - * - * If the file exists and the operation succeeded 0 is returned, otherwise - * on error -1 is returned. */ -int linenoiseHistoryLoad(const char *filename) { - File file; - file.open(filename, "r"); - char buf[LINENOISE_MAX_LINE]; - if (file.file == NULL) { - return -1; - } - - while (fgets(buf, LINENOISE_MAX_LINE, file.file) != NULL) { - char *p; - - p = strchr(buf,'\r'); - if (!p) p = strchr(buf,'\n'); - if (p) *p = '\0'; - linenoiseHistoryAdd(buf); - } - return 0; -} -#endif diff --git a/tools/run/linenoise.cpp/linenoise.h b/tools/run/linenoise.cpp/linenoise.h deleted file mode 100644 index 9823ca36d0..0000000000 --- a/tools/run/linenoise.cpp/linenoise.h +++ /dev/null @@ -1,137 +0,0 @@ -/* linenoise.h -- VERSION 1.0 - * - * Guerrilla line editing library against the idea that a line editing lib - * needs to be 20,000 lines of C++ code. - * - * See linenoise.cpp for more information. - * - * ------------------------------------------------------------------------ - * - * Copyright (c) 2010-2023, Salvatore Sanfilippo - * Copyright (c) 2010-2013, Pieter Noordhuis - * Copyright (c) 2025, Eric Curtin - * - * All rights reserved. - * - * Redistribution and use in source and binary forms, with or without - * modification, are permitted provided that the following conditions are - * met: - * - * * Redistributions of source code must retain the above copyright - * notice, this list of conditions and the following disclaimer. - * - * * Redistributions in binary form must reproduce the above copyright - * notice, this list of conditions and the following disclaimer in the - * documentation and/or other materials provided with the distribution. - * - * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS - * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT - * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR - * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT - * HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, - * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT - * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, - * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY - * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT - * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE - * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - */ - -#ifndef __LINENOISE_H -#define __LINENOISE_H - -#ifdef __cplusplus -extern "C" { -#endif - -#include /* For size_t. */ -#include - -extern const char * linenoiseEditMore; - -/* The linenoiseState structure represents the state during line editing. - * We pass this state to functions implementing specific editing - * functionalities. */ -struct linenoiseState { - int in_completion; /* The user pressed TAB and we are now in completion - * mode, so input is handled by completeLine(). */ - size_t completion_idx; /* Index of next completion to propose. */ - int ifd; /* Terminal stdin file descriptor. */ - int ofd; /* Terminal stdout file descriptor. */ - char * buf; /* Edited line buffer. */ - size_t buflen; /* Edited line buffer size. */ - const char * prompt; /* Prompt to display. */ - size_t plen; /* Prompt length. */ - size_t pos; /* Current cursor position. */ - size_t oldcolpos; /* Previous refresh cursor column position. */ - size_t len; /* Current edited line length. */ - size_t cols; /* Number of columns in terminal. */ - size_t oldrows; /* Rows used by last refreshed line (multiline mode) */ - int history_index; /* The history index we are currently editing. */ -}; - -struct linenoiseCompletions { - size_t len = 0; - char ** cvec = nullptr; - bool to_free = true; - - ~linenoiseCompletions() { - if (!to_free) { - return; - } - - for (size_t i = 0; i < len; ++i) { - free(cvec[i]); - } - - free(cvec); - } -}; - -/* Non blocking API. */ -int linenoiseEditStart(struct linenoiseState * l, int stdin_fd, int stdout_fd, char * buf, size_t buflen, - const char * prompt); -const char * linenoiseEditFeed(struct linenoiseState * l); -void linenoiseEditStop(struct linenoiseState * l); -void linenoiseHide(struct linenoiseState * l); -void linenoiseShow(struct linenoiseState * l); - -/* Blocking API. */ -const char * linenoise(const char * prompt); -void linenoiseFree(void * ptr); - -/* Completion API. */ -typedef void(linenoiseCompletionCallback)(const char *, linenoiseCompletions *); -typedef const char *(linenoiseHintsCallback) (const char *, int * color, int * bold); -typedef void(linenoiseFreeHintsCallback)(const char *); -void linenoiseSetCompletionCallback(linenoiseCompletionCallback *); -void linenoiseSetHintsCallback(linenoiseHintsCallback *); -void linenoiseSetFreeHintsCallback(linenoiseFreeHintsCallback *); -void linenoiseAddCompletion(linenoiseCompletions *, const char *); - -/* History API. */ -int linenoiseHistoryAdd(const char * line); -int linenoiseHistorySetMaxLen(int len); -int linenoiseHistorySave(const char * filename); -int linenoiseHistoryLoad(const char * filename); - -/* Other utilities. */ -void linenoiseClearScreen(void); -void linenoiseSetMultiLine(int ml); -void linenoisePrintKeyCodes(void); -void linenoiseMaskModeEnable(void); -void linenoiseMaskModeDisable(void); - -/* Encoding functions. */ -typedef size_t(linenoisePrevCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseNextCharLen)(const char * buf, size_t buf_len, size_t pos, size_t * col_len); -typedef size_t(linenoiseReadCode)(int fd, char * buf, size_t buf_len, int * c); - -void linenoiseSetEncodingFunctions(linenoisePrevCharLen * prevCharLenFunc, linenoiseNextCharLen * nextCharLenFunc, - linenoiseReadCode * readCodeFunc); - -#ifdef __cplusplus -} -#endif - -#endif /* __LINENOISE_H */ diff --git a/tools/run/run.cpp b/tools/run/run.cpp deleted file mode 100644 index b90a7253c4..0000000000 --- a/tools/run/run.cpp +++ /dev/null @@ -1,1408 +0,0 @@ -#include "chat.h" -#include "common.h" -#include "llama-cpp.h" -#include "log.h" - -#include "linenoise.cpp/linenoise.h" - -#define JSON_ASSERT GGML_ASSERT -#include - -#if defined(_WIN32) -# define WIN32_LEAN_AND_MEAN -# ifndef NOMINMAX -# define NOMINMAX -# endif -# include -# include -#else -# include -# include -# include -#endif - -#if defined(LLAMA_USE_CURL) -# include -#else -# include "http.h" -#endif - -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) || defined(_WIN32) -[[noreturn]] static void sigint_handler(int) { - printf("\n" LOG_COL_DEFAULT); - exit(0); // not ideal, but it's the only way to guarantee exit in all cases -} -#endif - -GGML_ATTRIBUTE_FORMAT(1, 2) -static int printe(const char * fmt, ...) { - va_list args; - va_start(args, fmt); - const int ret = vfprintf(stderr, fmt, args); - va_end(args); - - return ret; -} - -static std::string strftime_fmt(const char * fmt, const std::tm & tm) { - std::ostringstream oss; - oss << std::put_time(&tm, fmt); - - return oss.str(); -} - -class Opt { - public: - int init(int argc, const char ** argv) { - ctx_params = llama_context_default_params(); - model_params = llama_model_default_params(); - context_size_default = ctx_params.n_batch; - n_threads_default = ctx_params.n_threads; - ngl_default = model_params.n_gpu_layers; - common_params_sampling sampling; - temperature_default = sampling.temp; - - if (argc < 2) { - printe("Error: No arguments provided.\n"); - print_help(); - return 1; - } - - // Parse arguments - if (parse(argc, argv)) { - printe("Error: Failed to parse arguments.\n"); - print_help(); - return 1; - } - - // If help is requested, show help and exit - if (help) { - print_help(); - return 2; - } - - ctx_params.n_batch = context_size >= 0 ? context_size : context_size_default; - ctx_params.n_ctx = ctx_params.n_batch; - ctx_params.n_threads = ctx_params.n_threads_batch = n_threads >= 0 ? n_threads : n_threads_default; - model_params.n_gpu_layers = ngl >= 0 ? ngl : ngl_default; - temperature = temperature >= 0 ? temperature : temperature_default; - - return 0; // Success - } - - llama_context_params ctx_params; - llama_model_params model_params; - std::string model_; - std::string chat_template_file; - std::string user; - bool use_jinja = false; - int context_size = -1, ngl = -1, n_threads = -1; - float temperature = -1; - bool verbose = false; - - private: - int context_size_default = -1, ngl_default = -1, n_threads_default = -1; - float temperature_default = -1; - bool help = false; - - bool parse_flag(const char ** argv, int i, const char * short_opt, const char * long_opt) { - return strcmp(argv[i], short_opt) == 0 || strcmp(argv[i], long_opt) == 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, int & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atoi(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, float & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = std::atof(argv[++i]); - - return 0; - } - - int handle_option_with_value(int argc, const char ** argv, int & i, std::string & option_value) { - if (i + 1 >= argc) { - return 1; - } - - option_value = argv[++i]; - - return 0; - } - - int parse_options_with_value(int argc, const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (strcmp(argv[i], "-c") == 0 || strcmp(argv[i], "--context-size") == 0)) { - if (handle_option_with_value(argc, argv, i, context_size) == 1) { - return 1; - } - } else if (options_parsing && - (strcmp(argv[i], "-n") == 0 || strcmp(argv[i], "-ngl") == 0 || strcmp(argv[i], "--ngl") == 0)) { - if (handle_option_with_value(argc, argv, i, ngl) == 1) { - return 1; - } - } else if (options_parsing && (strcmp(argv[i], "-t") == 0 || strcmp(argv[i], "--threads") == 0)) { - if (handle_option_with_value(argc, argv, i, n_threads) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--temp") == 0) { - if (handle_option_with_value(argc, argv, i, temperature) == 1) { - return 1; - } - } else if (options_parsing && strcmp(argv[i], "--chat-template-file") == 0) { - if (handle_option_with_value(argc, argv, i, chat_template_file) == 1) { - return 1; - } - use_jinja = true; - } else { - return 2; - } - - return 0; - } - - int parse_options(const char ** argv, int & i, bool & options_parsing) { - if (options_parsing && (parse_flag(argv, i, "-v", "--verbose") || parse_flag(argv, i, "-v", "--log-verbose"))) { - verbose = true; - } else if (options_parsing && strcmp(argv[i], "--jinja") == 0) { - use_jinja = true; - } else if (options_parsing && parse_flag(argv, i, "-h", "--help")) { - help = true; - return 0; - } else if (options_parsing && strcmp(argv[i], "--") == 0) { - options_parsing = false; - } else { - return 2; - } - - return 0; - } - - int parse_positional_args(const char ** argv, int & i, int & positional_args_i) { - if (positional_args_i == 0) { - if (!argv[i][0] || argv[i][0] == '-') { - return 1; - } - - ++positional_args_i; - model_ = argv[i]; - } else if (positional_args_i == 1) { - ++positional_args_i; - user = argv[i]; - } else { - user += " " + std::string(argv[i]); - } - - return 0; - } - - int parse(int argc, const char ** argv) { - bool options_parsing = true; - for (int i = 1, positional_args_i = 0; i < argc; ++i) { - int ret = parse_options_with_value(argc, argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - ret = parse_options(argv, i, options_parsing); - if (ret == 0) { - continue; - } else if (ret == 1) { - return ret; - } - - if (parse_positional_args(argv, i, positional_args_i)) { - return 1; - } - } - - if (model_.empty()) { - return 1; - } - - return 0; - } - - void print_help() const { - printf( - "Description:\n" - " Runs a llm\n" - "\n" - "Usage:\n" - " llama-run [options] model [prompt]\n" - "\n" - "Options:\n" - " -c, --context-size \n" - " Context size (default: %d)\n" - " --chat-template-file \n" - " Path to the file containing the chat template to use with the model.\n" - " Only supports jinja templates and implicitly sets the --jinja flag.\n" - " --jinja\n" - " Use jinja templating for the chat template of the model\n" - " -n, -ngl, --ngl \n" - " Number of GPU layers (default: %d)\n" - " --temp \n" - " Temperature (default: %.1f)\n" - " -t, --threads \n" - " Number of threads to use during generation (default: %d)\n" - " -v, --verbose, --log-verbose\n" - " Set verbosity level to infinity (i.e. log all messages, useful for debugging)\n" - " -h, --help\n" - " Show help message\n" - "\n" - "Commands:\n" - " model\n" - " Model is a string with an optional prefix of \n" - " huggingface:// (hf://), modelscope:// (ms://), ollama://, https:// or file://.\n" - " If no protocol is specified and a file exists in the specified\n" - " path, file:// is assumed, otherwise if a file does not exist in\n" - " the specified path, ollama:// is assumed. Models that are being\n" - " pulled are downloaded with .partial extension while being\n" - " downloaded and then renamed as the file without the .partial\n" - " extension when complete.\n" - "\n" - "Examples:\n" - " llama-run llama3\n" - " llama-run ollama://granite-code\n" - " llama-run ollama://smollm:135m\n" - " llama-run hf://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "huggingface://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run ms://QuantFactory/SmolLM-135M-GGUF/SmolLM-135M.Q2_K.gguf\n" - " llama-run " - "modelscope://bartowski/SmolLM-1.7B-Instruct-v0.2-GGUF/SmolLM-1.7B-Instruct-v0.2-IQ3_M.gguf\n" - " llama-run https://example.com/some-file1.gguf\n" - " llama-run some-file2.gguf\n" - " llama-run file://some-file3.gguf\n" - " llama-run --ngl 999 some-file4.gguf\n" - " llama-run --ngl 999 some-file5.gguf Hello World\n", - context_size_default, ngl_default, temperature_default, n_threads_default); - } -}; - -struct progress_data { - size_t file_size = 0; - std::chrono::steady_clock::time_point start_time = std::chrono::steady_clock::now(); - bool printed = false; -}; - -static int get_terminal_width() { -#if defined(_WIN32) - CONSOLE_SCREEN_BUFFER_INFO csbi; - GetConsoleScreenBufferInfo(GetStdHandle(STD_OUTPUT_HANDLE), &csbi); - return csbi.srWindow.Right - csbi.srWindow.Left + 1; -#else - struct winsize w; - ioctl(STDOUT_FILENO, TIOCGWINSZ, &w); - return w.ws_col; -#endif -} - -class File { - public: - FILE * file = nullptr; - - FILE * open(const std::string & filename, const char * mode) { - file = ggml_fopen(filename.c_str(), mode); - - return file; - } - - int lock() { - if (file) { -# ifdef _WIN32 - fd = _fileno(file); - hFile = (HANDLE) _get_osfhandle(fd); - if (hFile == INVALID_HANDLE_VALUE) { - fd = -1; - - return 1; - } - - OVERLAPPED overlapped = {}; - if (!LockFileEx(hFile, LOCKFILE_EXCLUSIVE_LOCK | LOCKFILE_FAIL_IMMEDIATELY, 0, MAXDWORD, MAXDWORD, - &overlapped)) { - fd = -1; - - return 1; - } -# else - fd = fileno(file); - if (flock(fd, LOCK_EX | LOCK_NB) != 0) { - fd = -1; - - return 1; - } -# endif - } - - return 0; - } - - std::string to_string() { - fseek(file, 0, SEEK_END); - const size_t size = ftell(file); - fseek(file, 0, SEEK_SET); - std::string out; - out.resize(size); - const size_t read_size = fread(&out[0], 1, size, file); - if (read_size != size) { - printe("Error reading file: %s", strerror(errno)); - } - - return out; - } - - ~File() { - if (fd >= 0) { -# ifdef _WIN32 - if (hFile != INVALID_HANDLE_VALUE) { - OVERLAPPED overlapped = {}; - UnlockFileEx(hFile, 0, MAXDWORD, MAXDWORD, &overlapped); - } -# else - flock(fd, LOCK_UN); -# endif - } - - if (file) { - fclose(file); - } - } - - private: - int fd = -1; -# ifdef _WIN32 - HANDLE hFile = nullptr; -# endif -}; - -class HttpClient { - public: - int init(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - if (std::filesystem::exists(output_file)) { - return 0; - } - - std::string output_file_partial; - - if (!output_file.empty()) { - output_file_partial = output_file + ".partial"; - } - - if (download(url, headers, output_file_partial, progress, response_str)) { - return 1; - } - - if (!output_file.empty()) { - try { - std::filesystem::rename(output_file_partial, output_file); - } catch (const std::filesystem::filesystem_error & e) { - printe("Failed to rename '%s' to '%s': %s\n", output_file_partial.c_str(), output_file.c_str(), e.what()); - return 1; - } - } - - return 0; - } - -#ifdef LLAMA_USE_CURL - - ~HttpClient() { - if (chunk) { - curl_slist_free_all(chunk); - } - - if (curl) { - curl_easy_cleanup(curl); - } - } - - private: - CURL * curl = nullptr; - struct curl_slist * chunk = nullptr; - - int download(const std::string & url, const std::vector & headers, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - curl = curl_easy_init(); - if (!curl) { - return 1; - } - - progress_data data; - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - - return 1; - } - - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - - return 1; - } - } - - set_write_options(response_str, out); - data.file_size = set_resume_point(output_file); - set_progress_options(progress, data); - set_headers(headers); - CURLcode res = perform(url); - if (res != CURLE_OK){ - printe("Fetching resource '%s' failed: %s\n", url.c_str(), curl_easy_strerror(res)); - return 1; - } - - return 0; - } - - void set_write_options(std::string * response_str, const File & out) { - if (response_str) { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, capture_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, response_str); - } else { - curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, write_data); - curl_easy_setopt(curl, CURLOPT_WRITEDATA, out.file); - } - } - - size_t set_resume_point(const std::string & output_file) { - size_t file_size = 0; - if (std::filesystem::exists(output_file)) { - file_size = std::filesystem::file_size(output_file); - curl_easy_setopt(curl, CURLOPT_RESUME_FROM_LARGE, static_cast(file_size)); - } - - return file_size; - } - - void set_progress_options(bool progress, progress_data & data) { - if (progress) { - curl_easy_setopt(curl, CURLOPT_NOPROGRESS, 0L); - curl_easy_setopt(curl, CURLOPT_XFERINFODATA, &data); - curl_easy_setopt(curl, CURLOPT_XFERINFOFUNCTION, update_progress); - } - } - - void set_headers(const std::vector & headers) { - if (!headers.empty()) { - if (chunk) { - curl_slist_free_all(chunk); - chunk = 0; - } - - for (const auto & header : headers) { - chunk = curl_slist_append(chunk, header.c_str()); - } - - curl_easy_setopt(curl, CURLOPT_HTTPHEADER, chunk); - } - } - - CURLcode perform(const std::string & url) { - curl_easy_setopt(curl, CURLOPT_URL, url.c_str()); - curl_easy_setopt(curl, CURLOPT_FOLLOWLOCATION, 1L); - curl_easy_setopt(curl, CURLOPT_DEFAULT_PROTOCOL, "https"); - curl_easy_setopt(curl, CURLOPT_FAILONERROR, 1L); -#ifdef _WIN32 - curl_easy_setopt(curl, CURLOPT_SSL_OPTIONS, CURLSSLOPT_NATIVE_CA); -#endif - return curl_easy_perform(curl); - } - -#else // LLAMA_USE_CURL is not defined - -#define curl_off_t long long // temporary hack - - private: - // this is a direct translation of the cURL download() above - int download(const std::string & url, const std::vector & headers_vec, const std::string & output_file, - const bool progress, std::string * response_str = nullptr) { - try { - auto [cli, url_parts] = common_http_client(url); - - httplib::Headers headers; - for (const auto & h : headers_vec) { - size_t pos = h.find(':'); - if (pos != std::string::npos) { - headers.emplace(h.substr(0, pos), h.substr(pos + 2)); - } - } - - File out; - if (!output_file.empty()) { - if (!out.open(output_file, "ab")) { - printe("Failed to open file for writing\n"); - return 1; - } - if (out.lock()) { - printe("Failed to exclusively lock file\n"); - return 1; - } - } - - size_t resume_offset = 0; - if (!output_file.empty() && std::filesystem::exists(output_file)) { - resume_offset = std::filesystem::file_size(output_file); - if (resume_offset > 0) { - headers.emplace("Range", "bytes=" + std::to_string(resume_offset) + "-"); - } - } - - progress_data data; - data.file_size = resume_offset; - - long long total_size = 0; - long long received_this_session = 0; - - auto response_handler = - [&](const httplib::Response & response) { - if (resume_offset > 0 && response.status != 206) { - printe("\nServer does not support resuming. Restarting download.\n"); - out.file = freopen(output_file.c_str(), "wb", out.file); - if (!out.file) { - return false; - } - data.file_size = 0; - } - if (progress) { - if (response.has_header("Content-Length")) { - total_size = std::stoll(response.get_header_value("Content-Length")); - } else if (response.has_header("Content-Range")) { - auto range = response.get_header_value("Content-Range"); - auto slash = range.find('/'); - if (slash != std::string::npos) { - total_size = std::stoll(range.substr(slash + 1)); - } - } - } - return true; - }; - - auto content_receiver = - [&](const char * chunk, size_t length) { - if (out.file && fwrite(chunk, 1, length, out.file) != length) { - return false; - } - if (response_str) { - response_str->append(chunk, length); - } - received_this_session += length; - - if (progress && total_size > 0) { - update_progress(&data, total_size, received_this_session, 0, 0); - } - return true; - }; - - auto res = cli.Get(url_parts.path, headers, response_handler, content_receiver); - - if (data.printed) { - printe("\n"); - } - - if (!res) { - auto err = res.error(); - printe("Fetching resource '%s' failed: %s\n", url.c_str(), httplib::to_string(err).c_str()); - return 1; - } - - if (res->status >= 400) { - printe("Fetching resource '%s' failed with status code: %d\n", url.c_str(), res->status); - return 1; - } - - } catch (const std::exception & e) { - printe("HTTP request failed: %s\n", e.what()); - return 1; - } - return 0; - } - -#endif // LLAMA_USE_CURL - - static std::string human_readable_time(double seconds) { - int hrs = static_cast(seconds) / 3600; - int mins = (static_cast(seconds) % 3600) / 60; - int secs = static_cast(seconds) % 60; - - if (hrs > 0) { - return string_format("%dh %02dm %02ds", hrs, mins, secs); - } else if (mins > 0) { - return string_format("%dm %02ds", mins, secs); - } else { - return string_format("%ds", secs); - } - } - - static std::string human_readable_size(curl_off_t size) { - static const char * suffix[] = { "B", "KB", "MB", "GB", "TB" }; - char length = sizeof(suffix) / sizeof(suffix[0]); - int i = 0; - double dbl_size = size; - if (size > 1024) { - for (i = 0; (size / 1024) > 0 && i < length - 1; i++, size /= 1024) { - dbl_size = size / 1024.0; - } - } - - return string_format("%.2f %s", dbl_size, suffix[i]); - } - - static int update_progress(void * ptr, curl_off_t total_to_download, curl_off_t now_downloaded, curl_off_t, - curl_off_t) { - progress_data * data = static_cast(ptr); - if (total_to_download <= 0) { - return 0; - } - - total_to_download += data->file_size; - const curl_off_t now_downloaded_plus_file_size = now_downloaded + data->file_size; - const curl_off_t percentage = calculate_percentage(now_downloaded_plus_file_size, total_to_download); - std::string progress_prefix = generate_progress_prefix(percentage); - - const double speed = calculate_speed(now_downloaded, data->start_time); - const double tim = (total_to_download - now_downloaded) / speed; - std::string progress_suffix = - generate_progress_suffix(now_downloaded_plus_file_size, total_to_download, speed, tim); - - int progress_bar_width = calculate_progress_bar_width(progress_prefix, progress_suffix); - std::string progress_bar; - generate_progress_bar(progress_bar_width, percentage, progress_bar); - - print_progress(progress_prefix, progress_bar, progress_suffix); - data->printed = true; - - return 0; - } - - static curl_off_t calculate_percentage(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download) { - return (now_downloaded_plus_file_size * 100) / total_to_download; - } - - static std::string generate_progress_prefix(curl_off_t percentage) { - return string_format("%3ld%% |", static_cast(percentage)); - } - - static double calculate_speed(curl_off_t now_downloaded, const std::chrono::steady_clock::time_point & start_time) { - const auto now = std::chrono::steady_clock::now(); - const std::chrono::duration elapsed_seconds = now - start_time; - return now_downloaded / elapsed_seconds.count(); - } - - static std::string generate_progress_suffix(curl_off_t now_downloaded_plus_file_size, curl_off_t total_to_download, - double speed, double estimated_time) { - const int width = 10; - return string_format("%*s/%*s%*s/s%*s", width, human_readable_size(now_downloaded_plus_file_size).c_str(), - width, human_readable_size(total_to_download).c_str(), width, - human_readable_size(speed).c_str(), width, human_readable_time(estimated_time).c_str()); - } - - static int calculate_progress_bar_width(const std::string & progress_prefix, const std::string & progress_suffix) { - int progress_bar_width = get_terminal_width() - progress_prefix.size() - progress_suffix.size() - 3; - if (progress_bar_width < 1) { - progress_bar_width = 1; - } - - return progress_bar_width; - } - - static std::string generate_progress_bar(int progress_bar_width, curl_off_t percentage, - std::string & progress_bar) { - const curl_off_t pos = (percentage * progress_bar_width) / 100; - for (int i = 0; i < progress_bar_width; ++i) { - progress_bar.append((i < pos) ? "█" : " "); - } - - return progress_bar; - } - - static void print_progress(const std::string & progress_prefix, const std::string & progress_bar, - const std::string & progress_suffix) { - printe("\r" LOG_CLR_TO_EOL "%s%s| %s", progress_prefix.c_str(), progress_bar.c_str(), progress_suffix.c_str()); - } - // Function to write data to a file - static size_t write_data(void * ptr, size_t size, size_t nmemb, void * stream) { - FILE * out = static_cast(stream); - return fwrite(ptr, size, nmemb, out); - } - - // Function to capture data into a string - static size_t capture_data(void * ptr, size_t size, size_t nmemb, void * stream) { - std::string * str = static_cast(stream); - str->append(static_cast(ptr), size * nmemb); - return size * nmemb; - } - -}; - -class LlamaData { - public: - llama_model_ptr model; - llama_sampler_ptr sampler; - llama_context_ptr context; - std::vector messages; // TODO: switch to common_chat_msg - std::list msg_strs; - std::vector fmtted; - - int init(Opt & opt) { - model = initialize_model(opt); - if (!model) { - return 1; - } - - context = initialize_context(model, opt); - if (!context) { - return 1; - } - - sampler = initialize_sampler(opt); - - return 0; - } - - private: - int download(const std::string & url, const std::string & output_file, const bool progress, - const std::vector & headers = {}, std::string * response_str = nullptr) { - HttpClient http; - if (http.init(url, headers, output_file, progress, response_str)) { - return 1; - } - - return 0; - } - - // Helper function to handle model tag extraction and URL construction - std::pair extract_model_and_tag(std::string & model, const std::string & base_url) { - std::string model_tag = "latest"; - const size_t colon_pos = model.find(':'); - if (colon_pos != std::string::npos) { - model_tag = model.substr(colon_pos + 1); - model = model.substr(0, colon_pos); - } - - std::string url = base_url + model + "/manifests/" + model_tag; - - return { model, url }; - } - - // Helper function to download and parse the manifest - int download_and_parse_manifest(const std::string & url, const std::vector & headers, - nlohmann::json & manifest) { - std::string manifest_str; - int ret = download(url, "", false, headers, &manifest_str); - if (ret) { - return ret; - } - - manifest = nlohmann::json::parse(manifest_str); - - return 0; - } - - int dl_from_endpoint(std::string & model_endpoint, std::string & model, const std::string & bn) { - // Find the second occurrence of '/' after protocol string - size_t pos = model.find('/'); - pos = model.find('/', pos + 1); - std::string hfr, hff; - std::vector headers = { "User-Agent: llama-cpp", "Accept: application/json" }; - std::string url; - - if (pos == std::string::npos) { - auto [model_name, manifest_url] = extract_model_and_tag(model, model_endpoint + "v2/"); - hfr = model_name; - - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, headers, manifest); - if (ret) { - return ret; - } - - hff = manifest["ggufFile"]["rfilename"]; - } else { - hfr = model.substr(0, pos); - hff = model.substr(pos + 1); - } - - url = model_endpoint + hfr + "/resolve/main/" + hff; - - return download(url, bn, true, headers); - } - - int modelscope_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = "https://modelscope.cn/models/"; - return dl_from_endpoint(model_endpoint, model, bn); - } - - int huggingface_dl(std::string & model, const std::string & bn) { - std::string model_endpoint = get_model_endpoint(); - return dl_from_endpoint(model_endpoint, model, bn); - } - - int ollama_dl(std::string & model, const std::string & bn) { - const std::vector headers = { "Accept: application/vnd.docker.distribution.manifest.v2+json" }; - if (model.find('/') == std::string::npos) { - model = "library/" + model; - } - - auto [model_name, manifest_url] = extract_model_and_tag(model, "https://registry.ollama.ai/v2/"); - nlohmann::json manifest; - int ret = download_and_parse_manifest(manifest_url, {}, manifest); - if (ret) { - return ret; - } - - std::string layer; - for (const auto & l : manifest["layers"]) { - if (l["mediaType"] == "application/vnd.ollama.image.model") { - layer = l["digest"]; - break; - } - } - - std::string blob_url = "https://registry.ollama.ai/v2/" + model_name + "/blobs/" + layer; - - return download(blob_url, bn, true, headers); - } - - int github_dl(const std::string & model, const std::string & bn) { - std::string repository = model; - std::string branch = "main"; - const size_t at_pos = model.find('@'); - if (at_pos != std::string::npos) { - repository = model.substr(0, at_pos); - branch = model.substr(at_pos + 1); - } - - const std::vector repo_parts = string_split(repository, "/"); - if (repo_parts.size() < 3) { - printe("Invalid GitHub repository format\n"); - return 1; - } - - const std::string & org = repo_parts[0]; - const std::string & project = repo_parts[1]; - std::string url = "https://raw.githubusercontent.com/" + org + "/" + project + "/" + branch; - for (size_t i = 2; i < repo_parts.size(); ++i) { - url += "/" + repo_parts[i]; - } - - return download(url, bn, true); - } - - int s3_dl(const std::string & model, const std::string & bn) { - const size_t slash_pos = model.find('/'); - if (slash_pos == std::string::npos) { - return 1; - } - - const std::string bucket = model.substr(0, slash_pos); - const std::string key = model.substr(slash_pos + 1); - const char * access_key = std::getenv("AWS_ACCESS_KEY_ID"); - const char * secret_key = std::getenv("AWS_SECRET_ACCESS_KEY"); - if (!access_key || !secret_key) { - printe("AWS credentials not found in environment\n"); - return 1; - } - - // Generate AWS Signature Version 4 headers - // (Implementation requires HMAC-SHA256 and date handling) - // Get current timestamp - const time_t now = time(nullptr); - const tm tm = *gmtime(&now); - const std::string date = strftime_fmt("%Y%m%d", tm); - const std::string datetime = strftime_fmt("%Y%m%dT%H%M%SZ", tm); - const std::vector headers = { - "Authorization: AWS4-HMAC-SHA256 Credential=" + std::string(access_key) + "/" + date + - "/us-east-1/s3/aws4_request", - "x-amz-content-sha256: UNSIGNED-PAYLOAD", "x-amz-date: " + datetime - }; - - const std::string url = "https://" + bucket + ".s3.amazonaws.com/" + key; - - return download(url, bn, true, headers); - } - - std::string basename(const std::string & path) { - const size_t pos = path.find_last_of("/\\"); - if (pos == std::string::npos) { - return path; - } - - return path.substr(pos + 1); - } - - int rm_until_substring(std::string & model_, const std::string & substring) { - const std::string::size_type pos = model_.find(substring); - if (pos == std::string::npos) { - return 1; - } - - model_ = model_.substr(pos + substring.size()); // Skip past the substring - return 0; - } - - int resolve_model(std::string & model_) { - int ret = 0; - if (string_starts_with(model_, "file://") || std::filesystem::exists(model_)) { - rm_until_substring(model_, "://"); - - return ret; - } - - const std::string bn = basename(model_); - if (string_starts_with(model_, "hf://") || string_starts_with(model_, "huggingface://") || - string_starts_with(model_, "hf.co/")) { - rm_until_substring(model_, "hf.co/"); - rm_until_substring(model_, "://"); - ret = huggingface_dl(model_, bn); - } else if (string_starts_with(model_, "ms://") || string_starts_with(model_, "modelscope://")) { - rm_until_substring(model_, "://"); - ret = modelscope_dl(model_, bn); - } else if ((string_starts_with(model_, "https://") || string_starts_with(model_, "http://")) && - !string_starts_with(model_, "https://ollama.com/library/")) { - ret = download(model_, bn, true); - } else if (string_starts_with(model_, "github:") || string_starts_with(model_, "github://")) { - rm_until_substring(model_, "github:"); - rm_until_substring(model_, "://"); - ret = github_dl(model_, bn); - } else if (string_starts_with(model_, "s3://")) { - rm_until_substring(model_, "://"); - ret = s3_dl(model_, bn); - } else { // ollama:// or nothing - rm_until_substring(model_, "ollama.com/library/"); - rm_until_substring(model_, "://"); - ret = ollama_dl(model_, bn); - } - - model_ = bn; - - return ret; - } - - // Initializes the model and returns a unique pointer to it - llama_model_ptr initialize_model(Opt & opt) { - ggml_backend_load_all(); - resolve_model(opt.model_); - printe("\r" LOG_CLR_TO_EOL "Loading model"); - llama_model_ptr model(llama_model_load_from_file(opt.model_.c_str(), opt.model_params)); - if (!model) { - printe("%s: error: unable to load model from file: %s\n", __func__, opt.model_.c_str()); - } - - printe("\r" LOG_CLR_TO_EOL); - return model; - } - - // Initializes the context with the specified parameters - llama_context_ptr initialize_context(const llama_model_ptr & model, const Opt & opt) { - llama_context_ptr context(llama_init_from_model(model.get(), opt.ctx_params)); - if (!context) { - printe("%s: error: failed to create the llama_context\n", __func__); - } - - return context; - } - - // Initializes and configures the sampler - llama_sampler_ptr initialize_sampler(const Opt & opt) { - llama_sampler_ptr sampler(llama_sampler_chain_init(llama_sampler_chain_default_params())); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_min_p(0.05f, 1)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_temp(opt.temperature)); - llama_sampler_chain_add(sampler.get(), llama_sampler_init_dist(LLAMA_DEFAULT_SEED)); - - return sampler; - } -}; - -// Add a message to `messages` and store its content in `msg_strs` -static void add_message(const char * role, const std::string & text, LlamaData & llama_data) { - llama_data.msg_strs.push_back(std::move(text)); - llama_data.messages.push_back({ role, llama_data.msg_strs.back().c_str() }); -} - -// Function to apply the chat template and resize `formatted` if needed -static int apply_chat_template(const struct common_chat_templates * tmpls, LlamaData & llama_data, const bool append, bool use_jinja) { - common_chat_templates_inputs inputs; - for (const auto & msg : llama_data.messages) { - common_chat_msg cmsg; - cmsg.role = msg.role; - cmsg.content = msg.content; - inputs.messages.push_back(cmsg); - } - inputs.add_generation_prompt = append; - inputs.use_jinja = use_jinja; - - auto chat_params = common_chat_templates_apply(tmpls, inputs); - // TODO: use other params for tool calls. - auto result = chat_params.prompt; - llama_data.fmtted.resize(result.size() + 1); - memcpy(llama_data.fmtted.data(), result.c_str(), result.size() + 1); - return result.size(); -} - -// Function to tokenize the prompt -static int tokenize_prompt(const llama_vocab * vocab, const std::string & prompt, - std::vector & prompt_tokens, const LlamaData & llama_data) { - const bool is_first = llama_memory_seq_pos_max(llama_get_memory(llama_data.context.get()), 0) == -1; - int n_tokens = prompt.size() + 2 * is_first; - prompt_tokens.resize(n_tokens); - n_tokens = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (n_tokens == std::numeric_limits::min()) { - printe("tokenization failed: input too large\n"); - return -1; - } - if (n_tokens < 0) { - prompt_tokens.resize(-n_tokens); - int check = llama_tokenize(vocab, prompt.c_str(), prompt.size(), - prompt_tokens.data(), prompt_tokens.size(), - is_first, /*parse_special =*/true); - if (check != -n_tokens) { - printe("failed to tokenize the prompt (size mismatch)\n"); - return -1; - } - n_tokens = check; - } else { - prompt_tokens.resize(n_tokens); - } - return n_tokens; -} - -// Check if we have enough space in the context to evaluate this batch -static int check_context_size(const llama_context_ptr & ctx, const llama_batch & batch) { - const int n_ctx = llama_n_ctx(ctx.get()); - const int n_ctx_used = llama_memory_seq_pos_max(llama_get_memory(ctx.get()), 0); - if (n_ctx_used + batch.n_tokens > n_ctx) { - printf(LOG_COL_DEFAULT "\n"); - printe("context size exceeded\n"); - return 1; - } - - return 0; -} - -// convert the token to a string -static int convert_token_to_string(const llama_vocab * vocab, const llama_token token_id, std::string & piece) { - char buf[256]; - int n = llama_token_to_piece(vocab, token_id, buf, sizeof(buf), 0, true); - if (n < 0) { - printe("failed to convert token to piece\n"); - return 1; - } - - piece = std::string(buf, n); - return 0; -} - -static void print_word_and_concatenate_to_response(const std::string & piece, std::string & response) { - printf("%s", piece.c_str()); - fflush(stdout); - response += piece; -} - -// helper function to evaluate a prompt and generate a response -static int generate(LlamaData & llama_data, const std::string & prompt, std::string & response) { - const llama_vocab * vocab = llama_model_get_vocab(llama_data.model.get()); - - std::vector tokens; - if (tokenize_prompt(vocab, prompt, tokens, llama_data) < 0) { - return 1; - } - - // prepare a batch for the prompt - llama_batch batch = llama_batch_get_one(tokens.data(), tokens.size()); - llama_token new_token_id; - while (true) { - check_context_size(llama_data.context, batch); - if (llama_decode(llama_data.context.get(), batch)) { - printe("failed to decode\n"); - return 1; - } - - // sample the next token, check is it an end of generation? - new_token_id = llama_sampler_sample(llama_data.sampler.get(), llama_data.context.get(), -1); - if (llama_vocab_is_eog(vocab, new_token_id)) { - break; - } - - std::string piece; - if (convert_token_to_string(vocab, new_token_id, piece)) { - return 1; - } - - print_word_and_concatenate_to_response(piece, response); - - // prepare the next batch with the sampled token - batch = llama_batch_get_one(&new_token_id, 1); - } - - printf(LOG_COL_DEFAULT); - return 0; -} - -static int read_user_input(std::string & user_input) { - static const char * prompt_prefix_env = std::getenv("LLAMA_PROMPT_PREFIX"); - static const char * prompt_prefix = prompt_prefix_env ? prompt_prefix_env : "> "; -#ifdef WIN32 - printf("\r" LOG_CLR_TO_EOL LOG_COL_DEFAULT "%s", prompt_prefix); - - std::getline(std::cin, user_input); - if (std::cin.eof()) { - printf("\n"); - return 1; - } -#else - std::unique_ptr line(const_cast(linenoise(prompt_prefix)), free); - if (!line) { - return 1; - } - - user_input = line.get(); -#endif - - if (user_input == "/bye") { - return 1; - } - - if (user_input.empty()) { - return 2; - } - -#ifndef WIN32 - linenoiseHistoryAdd(line.get()); -#endif - - return 0; // Should have data in happy path -} - -// Function to generate a response based on the prompt -static int generate_response(LlamaData & llama_data, const std::string & prompt, std::string & response, - const bool stdout_a_terminal) { - // Set response color - if (stdout_a_terminal) { - printf(LOG_COL_YELLOW); - } - - if (generate(llama_data, prompt, response)) { - printe("failed to generate response\n"); - return 1; - } - - // End response with color reset and newline - printf("\n%s", stdout_a_terminal ? LOG_COL_DEFAULT : ""); - return 0; -} - -// Helper function to apply the chat template and handle errors -static int apply_chat_template_with_error_handling(const common_chat_templates * tmpls, LlamaData & llama_data, const bool append, int & output_length, bool use_jinja) { - const int new_len = apply_chat_template(tmpls, llama_data, append, use_jinja); - if (new_len < 0) { - printe("failed to apply the chat template\n"); - return -1; - } - - output_length = new_len; - return 0; -} - -// Helper function to handle user input -static int handle_user_input(std::string & user_input, const std::string & user) { - if (!user.empty()) { - user_input = user; - return 0; // No need for interactive input - } - - return read_user_input(user_input); // Returns true if input ends the loop -} - -static bool is_stdin_a_terminal() { -#if defined(_WIN32) - HANDLE hStdin = GetStdHandle(STD_INPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdin, &mode); -#else - return isatty(STDIN_FILENO); -#endif -} - -static bool is_stdout_a_terminal() { -#if defined(_WIN32) - HANDLE hStdout = GetStdHandle(STD_OUTPUT_HANDLE); - DWORD mode; - return GetConsoleMode(hStdout, &mode); -#else - return isatty(STDOUT_FILENO); -#endif -} - -// Function to handle user input -static int get_user_input(std::string & user_input, const std::string & user) { - while (true) { - const int ret = handle_user_input(user_input, user); - if (ret == 1) { - return 1; - } - - if (ret == 2) { - continue; - } - - break; - } - - return 0; -} - -// Reads a chat template file to be used -static std::string read_chat_template_file(const std::string & chat_template_file) { - File file; - if (!file.open(chat_template_file, "r")) { - printe("Error opening chat template file '%s': %s", chat_template_file.c_str(), strerror(errno)); - return ""; - } - - return file.to_string(); -} - -static int process_user_message(const Opt & opt, const std::string & user_input, LlamaData & llama_data, - const common_chat_templates_ptr & chat_templates, int & prev_len, - const bool stdout_a_terminal) { - add_message("user", opt.user.empty() ? user_input : opt.user, llama_data); - int new_len; - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, true, new_len, opt.use_jinja) < 0) { - return 1; - } - - std::string prompt(llama_data.fmtted.begin() + prev_len, llama_data.fmtted.begin() + new_len); - std::string response; - if (generate_response(llama_data, prompt, response, stdout_a_terminal)) { - return 1; - } - - if (!opt.user.empty()) { - return 2; - } - - add_message("assistant", response, llama_data); - if (apply_chat_template_with_error_handling(chat_templates.get(), llama_data, false, prev_len, opt.use_jinja) < 0) { - return 1; - } - - return 0; -} - -// Main chat loop function -static int chat_loop(LlamaData & llama_data, const Opt & opt) { - int prev_len = 0; - llama_data.fmtted.resize(llama_n_ctx(llama_data.context.get())); - std::string chat_template; - if (!opt.chat_template_file.empty()) { - chat_template = read_chat_template_file(opt.chat_template_file); - } - - common_chat_templates_ptr chat_templates = common_chat_templates_init(llama_data.model.get(), chat_template); - static const bool stdout_a_terminal = is_stdout_a_terminal(); - while (true) { - // Get user input - std::string user_input; - if (get_user_input(user_input, opt.user) == 1) { - return 0; - } - - const int ret = process_user_message(opt, user_input, llama_data, chat_templates, prev_len, stdout_a_terminal); - if (ret == 1) { - return 1; - } else if (ret == 2) { - break; - } - } - - return 0; -} - -static void log_callback(const enum ggml_log_level level, const char * text, void * p) { - const Opt * opt = static_cast(p); - if (opt->verbose || level == GGML_LOG_LEVEL_ERROR) { - printe("%s", text); - } -} - -static std::string read_pipe_data() { - std::ostringstream result; - result << std::cin.rdbuf(); // Read all data from std::cin - return result.str(); -} - -static void ctrl_c_handling() { -#if defined(__unix__) || (defined(__APPLE__) && defined(__MACH__)) - struct sigaction sigint_action; - sigint_action.sa_handler = sigint_handler; - sigemptyset(&sigint_action.sa_mask); - sigint_action.sa_flags = 0; - sigaction(SIGINT, &sigint_action, NULL); -#elif defined(_WIN32) - auto console_ctrl_handler = +[](DWORD ctrl_type) -> BOOL { - return (ctrl_type == CTRL_C_EVENT) ? (sigint_handler(SIGINT), true) : false; - }; - SetConsoleCtrlHandler(reinterpret_cast(console_ctrl_handler), true); -#endif -} - -int main(int argc, const char ** argv) { - ctrl_c_handling(); - Opt opt; - const int ret = opt.init(argc, argv); - if (ret == 2) { - return 0; - } else if (ret) { - return 1; - } - - if (!is_stdin_a_terminal()) { - if (!opt.user.empty()) { - opt.user += "\n\n"; - } - - opt.user += read_pipe_data(); - } - - llama_log_set(log_callback, &opt); - LlamaData llama_data; - if (llama_data.init(opt)) { - return 1; - } - - if (chat_loop(llama_data, opt)) { - return 1; - } - - return 0; -} diff --git a/tools/server/public/index.html.gz b/tools/server/public/index.html.gz index b3983b2b17..e572817dca 100644 Binary files a/tools/server/public/index.html.gz and b/tools/server/public/index.html.gz differ diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index b02afaefda..e4a0be44cc 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1385,16 +1385,21 @@ json format_response_rerank( std::vector get_token_probabilities(llama_context * ctx, int idx) { std::vector cur; + const auto * logits = llama_get_logits_ith(ctx, idx); + const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx); - const llama_model * model = llama_get_model(ctx); - const llama_vocab * vocab = llama_model_get_vocab(model); + const int n_logits = llama_get_sampled_logits_count_ith(ctx, idx); - const int n_vocab = llama_vocab_n_tokens(vocab); - - cur.resize(n_vocab); - for (llama_token token_id = 0; token_id < n_vocab; token_id++) { - cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + cur.resize(n_logits); + if (sampled_ids) { + for (int i = 0; i < n_logits; i++) { + cur[i] = llama_token_data{sampled_ids[i], logits[i], 0.0f}; + } + } else { + for (llama_token token_id = 0; token_id < n_logits; token_id++) { + cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f}; + } } // sort tokens by logits diff --git a/tools/server/server-context.cpp b/tools/server/server-context.cpp index 9726e02522..33635a1586 100644 --- a/tools/server/server-context.cpp +++ b/tools/server/server-context.cpp @@ -1148,6 +1148,25 @@ private: return false; } + const bool need_logits = task.params.sampling.n_probs > 0; + + bool backend_sampling = true; + + backend_sampling &= task.params.sampling.backend_sampling; + + // TODO: speculative decoding requires multiple samples per batch - not supported yet + backend_sampling &= !(slot.ctx_dft && task.params.speculative.n_max > 0); + + // TODO: getting post/pre sampling logits is not yet supported with backend sampling + backend_sampling &= !need_logits; + + // TODO: tmp until backend sampling is fully implemented + if (backend_sampling) { + llama_set_sampler(ctx, slot.id, common_sampler_get(slot.smpl.get())); + } else { + llama_set_sampler(ctx, slot.id, nullptr); + } + SLT_INF(slot, "sampler chain: %s\n", common_sampler_print(slot.smpl.get()).c_str()); } @@ -1486,9 +1505,9 @@ private: res->n_tokens = slot.task->n_tokens(); res->res_type = slot.task->params.res_type; - const int n_embd = llama_model_n_embd(model); + const int n_embd_out = llama_model_n_embd_out(model); - std::vector embd_res(n_embd, 0.0f); + std::vector embd_res(n_embd_out, 0.0f); for (int i = 0; i < batch.n_tokens; ++i) { if (!batch.logits[i] || batch.seq_id[i][0] != slot.id) { @@ -1505,18 +1524,18 @@ private: if (embd == nullptr) { SLT_ERR(slot, "failed to get embeddings, token = %d, seq_id = %d\n", batch.token[i], batch.seq_id[i][0]); - res->embedding.push_back(std::vector(n_embd, 0.0f)); + res->embedding.push_back(std::vector(n_embd_out, 0.0f)); continue; } // normalize only when there is pooling if (llama_pooling_type(slot.ctx) != LLAMA_POOLING_TYPE_NONE) { - common_embd_normalize(embd, embd_res.data(), n_embd, slot.task->params.embd_normalize); + common_embd_normalize(embd, embd_res.data(), n_embd_out, slot.task->params.embd_normalize); res->embedding.push_back(embd_res); break; } - res->embedding.emplace_back(embd, embd + n_embd); + res->embedding.emplace_back(embd, embd + n_embd_out); } SLT_DBG(slot, "%s", "sending embeddings\n"); diff --git a/tools/server/server-models.cpp b/tools/server/server-models.cpp index 56e1dc46b8..803cb02e6e 100644 --- a/tools/server/server-models.cpp +++ b/tools/server/server-models.cpp @@ -21,11 +21,13 @@ #ifdef _WIN32 #include +#include #else #include #include #include #include +extern char **environ; #endif #if defined(__APPLE__) && defined(__MACH__) @@ -99,6 +101,49 @@ static void unset_reserved_args(common_preset & preset, bool unset_model_args) { } } +#ifdef _WIN32 +static std::string wide_to_utf8(const wchar_t * ws) { + if (!ws || !*ws) { + return {}; + } + + const int len = static_cast(std::wcslen(ws)); + const int bytes = WideCharToMultiByte(CP_UTF8, 0, ws, len, nullptr, 0, nullptr, nullptr); + if (bytes == 0) { + return {}; + } + + std::string utf8(bytes, '\0'); + WideCharToMultiByte(CP_UTF8, 0, ws, len, utf8.data(), bytes, nullptr, nullptr); + + return utf8; +} +#endif + +static std::vector get_environment() { + std::vector env; + +#ifdef _WIN32 + LPWCH env_block = GetEnvironmentStringsW(); + if (!env_block) { + return env; + } + for (LPWCH e = env_block; *e; e += wcslen(e) + 1) { + env.emplace_back(wide_to_utf8(e)); + } + FreeEnvironmentStringsW(env_block); +#else + if (environ == nullptr) { + return env; + } + for (char ** e = environ; *e != nullptr; e++) { + env.emplace_back(*e); + } +#endif + + return env; +} + void server_model_meta::update_args(common_preset_context & ctx_preset, std::string bin_path) { // update params unset_reserved_args(preset, false); @@ -117,14 +162,11 @@ void server_model_meta::update_args(common_preset_context & ctx_preset, std::str server_models::server_models( const common_params & params, int argc, - char ** argv, - char ** envp) + char ** argv) : ctx_preset(LLAMA_EXAMPLE_SERVER), base_params(params), + base_env(get_environment()), base_preset(ctx_preset.load_from_args(argc, argv)) { - for (char ** env = envp; *env != nullptr; env++) { - base_env.push_back(std::string(*env)); - } // clean up base preset unset_reserved_args(base_preset, true); // set binary path diff --git a/tools/server/server-models.h b/tools/server/server-models.h index 24ddc65662..a397abda4a 100644 --- a/tools/server/server-models.h +++ b/tools/server/server-models.h @@ -105,7 +105,7 @@ private: void add_model(server_model_meta && meta); public: - server_models(const common_params & params, int argc, char ** argv, char ** envp); + server_models(const common_params & params, int argc, char ** argv); void load_models(); @@ -147,8 +147,8 @@ struct server_models_routes { common_params params; json webui_settings = json::object(); server_models models; - server_models_routes(const common_params & params, int argc, char ** argv, char ** envp) - : params(params), models(params, argc, argv, envp) { + server_models_routes(const common_params & params, int argc, char ** argv) + : params(params), models(params, argc, argv) { if (!this->params.webui_config_json.empty()) { try { webui_settings = json::parse(this->params.webui_config_json); diff --git a/tools/server/server-task.cpp b/tools/server/server-task.cpp index 22f5b2059c..ed4f6546ea 100644 --- a/tools/server/server-task.cpp +++ b/tools/server/server-task.cpp @@ -78,6 +78,7 @@ json task_params::to_json(bool only_metrics) const { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, {"lora", lora}, }; } @@ -136,6 +137,7 @@ json task_params::to_json(bool only_metrics) const { {"speculative.p_min", speculative.p_min}, {"timings_per_token", timings_per_token}, {"post_sampling_probs", post_sampling_probs}, + {"backend_sampling", sampling.backend_sampling}, {"lora", lora}, }; } @@ -204,6 +206,7 @@ task_params server_task::params_from_json_cmpl( params.sampling.seed = json_value(data, "seed", defaults.sampling.seed); params.sampling.n_probs = json_value(data, "n_probs", defaults.sampling.n_probs); params.sampling.min_keep = json_value(data, "min_keep", defaults.sampling.min_keep); + params.sampling.backend_sampling = json_value(data, "backend_sampling", defaults.sampling.backend_sampling); params.post_sampling_probs = json_value(data, "post_sampling_probs", defaults.post_sampling_probs); params.speculative.n_min = json_value(data, "speculative.n_min", defaults.speculative.n_min); @@ -811,6 +814,15 @@ json server_task_result_cmpl_final::to_json_anthropic() { msg.content = content; } + // thinking block comes first (Anthropic extended thinking format) + if (!msg.reasoning_content.empty()) { + content_blocks.push_back({ + {"type", "thinking"}, + {"thinking", msg.reasoning_content}, + {"signature", ""} // empty signature for local models (no cryptographic verification) + }); + } + if (!msg.content.empty()) { content_blocks.push_back({ {"type", "text"}, @@ -859,20 +871,57 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use"; } - bool has_text = !oaicompat_msg.content.empty(); + bool has_thinking = !oaicompat_msg.reasoning_content.empty(); + bool has_text = !oaicompat_msg.content.empty(); size_t num_tool_calls = oaicompat_msg.tool_calls.size(); - bool text_block_started = false; + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + size_t text_block_index = has_thinking ? 1 : 0; + + bool thinking_block_started = false; + bool text_block_started = false; std::unordered_set tool_calls_started; for (const auto & diff : oaicompat_msg_diffs) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_block_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", thinking_block_index}, + {"content_block", { + {"type", "thinking"}, + {"thinking", ""} + }} + }} + }); + thinking_block_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content if (!diff.content_delta.empty()) { if (!text_block_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", text_block_index}, {"content_block", { {"type", "text"}, {"text", ""} @@ -886,7 +935,7 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -895,8 +944,9 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + diff.tool_call_index; if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) { const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index]; @@ -932,18 +982,42 @@ json server_task_result_cmpl_final::to_json_anthropic_stream() { } } + // close content blocks in order + if (has_thinking) { + // Anthropic API requires a signature_delta before closing thinking blocks + // We use an empty signature since we can't generate a cryptographic signature for local models + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", thinking_block_index}, + {"delta", { + {"type", "signature_delta"}, + {"signature", ""} + }} + }} + }); + events.push_back({ + {"event", "content_block_stop"}, + {"data", { + {"type", "content_block_stop"}, + {"index", thinking_block_index} + }} + }); + } + if (has_text) { events.push_back({ {"event", "content_block_stop"}, {"data", { {"type", "content_block_stop"}, - {"index", 0} + {"index", text_block_index} }} }); } for (size_t i = 0; i < num_tool_calls; i++) { - size_t content_block_index = (has_text ? 1 : 0) + i; + size_t content_block_index = (has_thinking ? 1 : 0) + (has_text ? 1 : 0) + i; events.push_back({ {"event", "content_block_stop"}, {"data", { @@ -1151,11 +1225,10 @@ json server_task_result_rerank::to_json() { json server_task_result_cmpl_partial::to_json_anthropic() { json events = json::array(); bool first = (n_decoded == 1); - bool text_block_started = false; + // use member variables to track block state across streaming calls + // (anthropic_thinking_block_started, anthropic_text_block_started) if (first) { - text_block_started = false; - events.push_back({ {"event", "message_start"}, {"data", { @@ -1177,28 +1250,69 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // content block indices: thinking (0) -> text (0 or 1) -> tool_use (n+) + size_t thinking_block_index = 0; + // use anthropic_has_reasoning (set in update()) to know if ANY reasoning was generated + size_t text_block_index = anthropic_has_reasoning ? 1 : 0; + + // use local copies of streaming state (copied from task_result_state in update()) + // these reflect the state BEFORE this chunk was processed + bool thinking_started = anthropic_thinking_block_started; + bool text_started = anthropic_text_block_started; + for (const auto & diff : oaicompat_msg_diffs) { - if (!diff.content_delta.empty()) { - if (!text_block_started) { + // handle thinking/reasoning content + if (!diff.reasoning_content_delta.empty()) { + if (!thinking_started) { events.push_back({ {"event", "content_block_start"}, {"data", { {"type", "content_block_start"}, - {"index", 0}, + {"index", thinking_block_index}, {"content_block", { - {"type", "text"}, - {"text", ""} + {"type", "thinking"}, + {"thinking", ""} }} }} }); - text_block_started = true; + thinking_started = true; } events.push_back({ {"event", "content_block_delta"}, {"data", { {"type", "content_block_delta"}, - {"index", 0}, + {"index", thinking_block_index}, + {"delta", { + {"type", "thinking_delta"}, + {"thinking", diff.reasoning_content_delta} + }} + }} + }); + } + + // handle regular text content + if (!diff.content_delta.empty()) { + if (!text_started) { + events.push_back({ + {"event", "content_block_start"}, + {"data", { + {"type", "content_block_start"}, + {"index", text_block_index}, + {"content_block", { + {"type", "text"}, + {"text", ""} + }} + }} + }); + text_started = true; + } + + events.push_back({ + {"event", "content_block_delta"}, + {"data", { + {"type", "content_block_delta"}, + {"index", text_block_index}, {"delta", { {"type", "text_delta"}, {"text", diff.content_delta} @@ -1207,8 +1321,10 @@ json server_task_result_cmpl_partial::to_json_anthropic() { }); } + // handle tool calls if (diff.tool_call_index != std::string::npos) { - size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index; + // use anthropic_has_reasoning for thinking block count (persists across calls) + size_t content_block_index = (anthropic_has_reasoning ? 1 : 0) + (text_started ? 1 : 0) + diff.tool_call_index; if (!diff.tool_call_delta.name.empty()) { events.push_back({ diff --git a/tools/server/server-task.h b/tools/server/server-task.h index 687770de5e..ead1491182 100644 --- a/tools/server/server-task.h +++ b/tools/server/server-task.h @@ -96,6 +96,10 @@ struct task_result_state { std::string generated_text; // append new chunks of generated text here std::vector generated_tool_call_ids; + // for Anthropic API streaming: track content block state across chunks + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + task_result_state(const common_chat_syntax & oaicompat_chat_syntax) : oaicompat_chat_syntax(oaicompat_chat_syntax) {} @@ -337,6 +341,12 @@ struct server_task_result_cmpl_partial : server_task_result { std::vector oaicompat_msg_diffs; // to be populated by update() bool is_updated = false; + // for Anthropic API: track if any reasoning content has been generated + bool anthropic_has_reasoning = false; + // Streaming state copied from task_result_state for this chunk + bool anthropic_thinking_block_started = false; + bool anthropic_text_block_started = false; + virtual bool is_stop() override { return false; // in stream mode, partial responses are not considered stop } @@ -346,6 +356,22 @@ struct server_task_result_cmpl_partial : server_task_result { virtual void update(task_result_state & state) override { is_updated = true; state.update_chat_msg(content, true, oaicompat_msg_diffs); + // track if the accumulated message has any reasoning content + anthropic_has_reasoning = !state.chat_msg.reasoning_content.empty(); + + // Copy current state for use in to_json_anthropic() (reflects state BEFORE this chunk) + anthropic_thinking_block_started = state.anthropic_thinking_block_started; + anthropic_text_block_started = state.anthropic_text_block_started; + + // Pre-compute state updates based on diffs (for next chunk) + for (const auto & diff : oaicompat_msg_diffs) { + if (!diff.reasoning_content_delta.empty() && !state.anthropic_thinking_block_started) { + state.anthropic_thinking_block_started = true; + } + if (!diff.content_delta.empty() && !state.anthropic_text_block_started) { + state.anthropic_text_block_started = true; + } + } } json to_json_non_oaicompat(); diff --git a/tools/server/server.cpp b/tools/server/server.cpp index 0fbc7b6d35..1d9abf6055 100644 --- a/tools/server/server.cpp +++ b/tools/server/server.cpp @@ -66,7 +66,7 @@ static server_http_context::handler_t ex_wrapper(server_http_context::handler_t }; } -int main(int argc, char ** argv, char ** envp) { +int main(int argc, char ** argv) { // own arguments required by this example common_params params; @@ -126,7 +126,7 @@ int main(int argc, char ** argv, char ** envp) { if (is_router_server) { // setup server instances manager try { - models_routes.emplace(params, argc, argv, envp); + models_routes.emplace(params, argc, argv); } catch (const std::exception & e) { LOG_ERR("%s: failed to initialize router models: %s\n", __func__, e.what()); return 1; diff --git a/tools/server/tests/unit/test_compat_anthropic.py b/tools/server/tests/unit/test_compat_anthropic.py index e0a003557e..e16e0235c6 100644 --- a/tools/server/tests/unit/test_compat_anthropic.py +++ b/tools/server/tests/unit/test_compat_anthropic.py @@ -805,3 +805,92 @@ def test_anthropic_vs_openai_different_response_format(): assert "input_tokens" in anthropic_res.body["usage"] assert "completion_tokens" in openai_res.body["usage"] assert "output_tokens" in anthropic_res.body["usage"] + + +# Extended thinking tests with reasoning models + +@pytest.mark.slow +@pytest.mark.parametrize("stream", [False, True]) +def test_anthropic_thinking_with_reasoning_model(stream): + """Test that thinking content blocks are properly returned for reasoning models""" + global server + server = ServerProcess() + server.model_hf_repo = "bartowski/DeepSeek-R1-Distill-Qwen-7B-GGUF" + server.model_hf_file = "DeepSeek-R1-Distill-Qwen-7B-Q4_K_M.gguf" + server.reasoning_format = "deepseek" + server.jinja = True + server.n_ctx = 8192 + server.n_predict = 1024 + server.server_port = 8084 + server.start(timeout_seconds=600) # large model needs time to download + + if stream: + res = server.make_stream_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ], + "stream": True + }) + + events = list(res) + + # should have thinking content block events + thinking_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "thinking"] + assert len(thinking_starts) > 0, "Should have thinking content_block_start event" + assert thinking_starts[0]["index"] == 0, "Thinking block should be at index 0" + + # should have thinking_delta events + thinking_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "thinking_delta"] + assert len(thinking_deltas) > 0, "Should have thinking_delta events" + + # should have signature_delta event before thinking block closes (Anthropic API requirement) + signature_deltas = [e for e in events if + e.get("type") == "content_block_delta" and + e.get("delta", {}).get("type") == "signature_delta"] + assert len(signature_deltas) > 0, "Should have signature_delta event for thinking block" + + # should have text block after thinking + text_starts = [e for e in events if + e.get("type") == "content_block_start" and + e.get("content_block", {}).get("type") == "text"] + assert len(text_starts) > 0, "Should have text content_block_start event" + assert text_starts[0]["index"] == 1, "Text block should be at index 1 (after thinking)" + else: + res = server.make_request("POST", "/v1/messages", data={ + "model": "test", + "max_tokens": 1024, + "thinking": { + "type": "enabled", + "budget_tokens": 500 + }, + "messages": [ + {"role": "user", "content": "What is 2+2?"} + ] + }) + + assert res.status_code == 200 + assert res.body["type"] == "message" + + content = res.body["content"] + assert len(content) >= 2, "Should have at least thinking and text blocks" + + # first block should be thinking + thinking_blocks = [b for b in content if b.get("type") == "thinking"] + assert len(thinking_blocks) > 0, "Should have thinking content block" + assert "thinking" in thinking_blocks[0], "Thinking block should have 'thinking' field" + assert len(thinking_blocks[0]["thinking"]) > 0, "Thinking content should not be empty" + assert "signature" in thinking_blocks[0], "Thinking block should have 'signature' field (Anthropic API requirement)" + + # should also have text block + text_blocks = [b for b in content if b.get("type") == "text"] + assert len(text_blocks) > 0, "Should have text content block" diff --git a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte index 4ec9b478fd..5a668aa300 100644 --- a/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte +++ b/tools/server/webui/src/lib/components/app/chat/ChatSettings/ChatSettings.svelte @@ -185,6 +185,11 @@ key: 'samplers', label: 'Samplers', type: 'input' + }, + { + key: 'backend_sampling', + label: 'Backend sampling', + type: 'checkbox' } ] }, diff --git a/tools/server/webui/src/lib/constants/settings-config.ts b/tools/server/webui/src/lib/constants/settings-config.ts index f9584d01d7..cac48a557c 100644 --- a/tools/server/webui/src/lib/constants/settings-config.ts +++ b/tools/server/webui/src/lib/constants/settings-config.ts @@ -21,6 +21,7 @@ export const SETTING_CONFIG_DEFAULT: Record = autoMicOnEmpty: false, // make sure these default values are in sync with `common.h` samplers: 'top_k;typ_p;top_p;min_p;temperature', + backend_sampling: false, temperature: 0.8, dynatemp_range: 0.0, dynatemp_exponent: 1.0, @@ -57,6 +58,8 @@ export const SETTING_CONFIG_INFO: Record = { 'When copying a message with text attachments, combine them into a single plain text string instead of a special format that can be pasted back as attachments.', samplers: 'The order at which samplers are applied, in simplified way. Default is "top_k;typ_p;top_p;min_p;temperature": top_k->typ_p->top_p->min_p->temperature', + backend_sampling: + 'Enable backend-based samplers. When enabled, supported samplers run on the accelerator backend for faster sampling.', temperature: 'Controls the randomness of the generated text by affecting the probability distribution of the output tokens. Higher = more random, lower = more focused.', dynatemp_range: diff --git a/tools/server/webui/src/lib/services/chat.ts b/tools/server/webui/src/lib/services/chat.ts index 86648f3cba..02fc6381c0 100644 --- a/tools/server/webui/src/lib/services/chat.ts +++ b/tools/server/webui/src/lib/services/chat.ts @@ -86,6 +86,7 @@ export class ChatService { dry_penalty_last_n, // Other parameters samplers, + backend_sampling, custom, timings_per_token, // Config options @@ -159,6 +160,8 @@ export class ChatService { : samplers; } + if (backend_sampling !== undefined) requestBody.backend_sampling = backend_sampling; + if (timings_per_token !== undefined) requestBody.timings_per_token = timings_per_token; if (custom) { diff --git a/tools/server/webui/src/lib/stores/chat.svelte.ts b/tools/server/webui/src/lib/stores/chat.svelte.ts index 67157e36ac..879b2f3245 100644 --- a/tools/server/webui/src/lib/stores/chat.svelte.ts +++ b/tools/server/webui/src/lib/stores/chat.svelte.ts @@ -1461,6 +1461,8 @@ class ChatStore { if (hasValue(currentConfig.dry_penalty_last_n)) apiOptions.dry_penalty_last_n = Number(currentConfig.dry_penalty_last_n); if (currentConfig.samplers) apiOptions.samplers = currentConfig.samplers; + if (currentConfig.backend_sampling) + apiOptions.backend_sampling = currentConfig.backend_sampling; if (currentConfig.custom) apiOptions.custom = currentConfig.custom; return apiOptions; diff --git a/tools/server/webui/src/lib/types/api.d.ts b/tools/server/webui/src/lib/types/api.d.ts index c2ecc02820..714509f024 100644 --- a/tools/server/webui/src/lib/types/api.d.ts +++ b/tools/server/webui/src/lib/types/api.d.ts @@ -149,6 +149,7 @@ export interface ApiLlamaCppServerProps { reasoning_in_content: boolean; thinking_forced_open: boolean; samplers: string[]; + backend_sampling: boolean; 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; @@ -212,6 +213,7 @@ export interface ApiChatCompletionRequest { dry_penalty_last_n?: number; // Sampler configuration samplers?: string[]; + backend_sampling?: boolean; // Custom parameters (JSON string) custom?: Record; timings_per_token?: boolean; @@ -312,6 +314,7 @@ export interface ApiSlotData { reasoning_in_content: boolean; thinking_forced_open: boolean; samplers: string[]; + backend_sampling: boolean; 'speculative.n_max': number; 'speculative.n_min': number; 'speculative.p_min': number; diff --git a/tools/server/webui/src/lib/types/settings.d.ts b/tools/server/webui/src/lib/types/settings.d.ts index e09f0f332c..38b3047dd0 100644 --- a/tools/server/webui/src/lib/types/settings.d.ts +++ b/tools/server/webui/src/lib/types/settings.d.ts @@ -43,6 +43,7 @@ export interface SettingsChatServiceOptions { dry_penalty_last_n?: number; // Sampler configuration samplers?: string | string[]; + backend_sampling?: boolean; // Custom parameters custom?: string; timings_per_token?: boolean;