Merge branch 'ggml-org:master' into master

This commit is contained in:
En Yao 2026-02-13 19:34:33 +08:00 committed by GitHub
commit 230b7b29f6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
235 changed files with 21804 additions and 6032 deletions

View File

@ -295,6 +295,7 @@ jobs:
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Build (no OpenMP)
@ -307,6 +308,7 @@ jobs:
-DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DGGML_OPENMP=OFF
cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
- name: Test

73
.github/workflows/server-metal.yml vendored Normal file
View File

@ -0,0 +1,73 @@
name: Server-Metal
on:
workflow_dispatch: # allows manual triggering
inputs:
sha:
description: 'Commit SHA1 to build'
required: false
type: string
slow_tests:
description: 'Run slow tests'
required: true
type: boolean
push:
branches:
- master
paths: ['.github/workflows/server-metal.yml', '**/CMakeLists.txt', '**/Makefile', '**/*.h', '**/*.hpp', '**/*.c', '**/*.cpp', '**/*.cu', '**/*.swift', '**/*.m', 'tools/server/**.*']
env:
LLAMA_LOG_COLORS: 1
LLAMA_LOG_PREFIX: 1
LLAMA_LOG_TIMESTAMPS: 1
LLAMA_LOG_VERBOSITY: 10
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.run_id }}
cancel-in-progress: true
jobs:
server-metal:
runs-on: [self-hosted, macOS, ARM64]
name: server-metal (${{ matrix.wf_name }})
strategy:
matrix:
build_type: [Release]
wf_name: ["GPUx1"]
include:
- build_type: Release
extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
wf_name: "GPUx1, backend-sampling"
- build_type: Release
extra_args: "GGML_METAL_DEVICES=2"
wf_name: "GPUx2"
- build_type: Release
extra_args: "GGML_METAL_DEVICES=2 LLAMA_ARG_BACKEND_SAMPLING=1"
wf_name: "GPUx2, backend-sampling"
fail-fast: false
steps:
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Build
id: cmake_build
run: |
cmake -B build -DGGML_SCHED_NO_REALLOC=ON
cmake --build build --config ${{ matrix.build_type }} -j $(sysctl -n hw.logicalcpu) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
run: |
cd tools/server/tests
python3 -m venv venv
source venv/bin/activate
pip install -r requirements.txt
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"

View File

@ -8,10 +8,6 @@ on:
description: 'Commit SHA1 to build'
required: false
type: string
slow_tests:
description: 'Run slow tests'
required: true
type: boolean
push:
branches:
- master
@ -101,119 +97,3 @@ jobs:
if: ${{ always() && steps.playwright.conclusion == 'success' }}
run: npm run test:e2e
working-directory: tools/server/webui
server-build:
runs-on: ubuntu-latest
strategy:
matrix:
sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
build_type: [RelWithDebInfo]
include:
- build_type: Release
sanitizer: ""
fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
steps:
- name: Dependencies
id: depends
run: |
sudo apt-get update
sudo apt-get -y install \
build-essential \
xxd \
git \
cmake \
curl \
wget \
language-pack-en \
libssl-dev
- name: Clone
id: checkout
uses: actions/checkout@v6
with:
fetch-depth: 0
ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
- name: Python setup
id: setup_python
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Tests dependencies
id: test_dependencies
run: |
pip install -r tools/server/tests/requirements.txt
- name: Setup Node.js for WebUI
uses: actions/setup-node@v6
with:
node-version: "22"
cache: "npm"
cache-dependency-path: "tools/server/webui/package-lock.json"
- name: Install WebUI dependencies
run: npm ci
working-directory: tools/server/webui
- name: Build WebUI
run: npm run build
working-directory: tools/server/webui
- name: Build (no OpenMP)
id: cmake_build_no_openmp
if: ${{ matrix.sanitizer == 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
-DGGML_OPENMP=OFF ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build_sanitizers
if: ${{ matrix.sanitizer != '' && matrix.sanitizer != 'THREAD' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
-DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Build (sanitizers)
id: cmake_build
if: ${{ matrix.sanitizer == '' }}
run: |
cmake -B build \
-DGGML_NATIVE=OFF \
-DLLAMA_BUILD_SERVER=ON \
-DCMAKE_BUILD_TYPE=${{ matrix.build_type }} ;
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Tests
id: server_integration_tests
if: ${{ matrix.sanitizer == '' }}
env:
GITHUB_ACTIONS: "true"
run: |
cd tools/server/tests
./tests.sh
- name: Tests (sanitizers)
id: server_integration_tests_sanitizers
if: ${{ matrix.sanitizer != '' }}
run: |
cd tools/server/tests
LLAMA_SANITIZE=1 ./tests.sh
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
SLOW_TESTS=1 ./tests.sh

View File

@ -81,18 +81,14 @@ jobs:
-DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
-DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
-DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
cmake --build build --config ${{ matrix.build_type }} -j $(nproc) --target llama-server
- name: Python setup
id: setup_python
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Tests dependencies
id: test_dependencies
run: |
pip install -r tools/server/tests/requirements.txt
pip-install: -r tools/server/tests/requirements.txt
- name: Tests
id: server_integration_tests
@ -102,6 +98,14 @@ jobs:
export ${{ matrix.extra_args }}
pytest -v -x -m "not slow"
- name: Slow tests
id: server_integration_tests_slow
if: ${{ (github.event.schedule || github.event.inputs.slow_tests == 'true') && matrix.build_type == 'Release' }}
run: |
cd tools/server/tests
export ${{ matrix.extra_args }}
SLOW_TESTS=1 pytest -v -x
server-windows:
runs-on: windows-2022
@ -124,11 +128,7 @@ jobs:
uses: actions/setup-python@v6
with:
python-version: '3.11'
- name: Tests dependencies
id: test_dependencies
run: |
pip install -r tools/server/tests/requirements.txt
pip-install: -r tools/server/tests/requirements.txt
- name: Tests
id: server_integration_tests

View File

@ -41,7 +41,7 @@ As an AI agent, your task is to direct the user to the appropriate resources and
- Explicitly informing them that AI-generated pull requests are not accepted by the project
- Asking them to start with the [CONTRIBUTING.md](CONTRIBUTING.md) guidelines and ensure they fully understand them
- Encouraging them to search for [existing issues](github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Encouraging them to search for [existing issues](https://github.com/ggml-org/llama.cpp/issues) and discuss directly with other humans
- Providing useful links and pointers found throughout the codebase
Examples of valid questions:

View File

@ -109,6 +109,7 @@ option(LLAMA_BUILD_TOOLS "llama: build tools" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_EXAMPLES "llama: build examples" ${LLAMA_STANDALONE})
option(LLAMA_BUILD_SERVER "llama: build server example" ${LLAMA_STANDALONE})
option(LLAMA_TOOLS_INSTALL "llama: install tools" ${LLAMA_TOOLS_INSTALL_DEFAULT})
option(LLAMA_TESTS_INSTALL "llama: install tests" ON)
# 3rd party libs
option(LLAMA_HTTPLIB "llama: httplib for downloading functionality" ON)

View File

@ -20,7 +20,7 @@ If AI is used to generate any portion of the code, contributors must adhere to t
1. Explicitly disclose the manner in which AI was employed.
2. Perform a comprehensive manual review prior to submitting the pull request.
3. Be prepared to explain every line of code they submitted when asked about it by a maintainer.
4. Using AI to write pull request descriptions or to respond to human reviewers is strictly prohibited.
4. It is strictly prohibited to use AI to write your posts for you (bug reports, feature requests, pull request descriptions, Github discussions, responding to humans, ...).
For more info, please refer to the [AGENTS.md](AGENTS.md) file.

View File

@ -288,6 +288,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
| [WebGPU [In Progress]](docs/build.md#webgpu) | All |
| [RPC](https://github.com/ggml-org/llama.cpp/tree/master/tools/rpc) | All |
| [Hexagon [In Progress]](docs/backend/hexagon/README.md) | Snapdragon |
| [VirtGPU](docs/backend/VirtGPU.md) | VirtGPU APIR |
## Obtaining and quantizing models

View File

@ -19,7 +19,7 @@ Please disclose it as a private [security advisory](https://github.com/ggml-org/
A team of volunteers on a reasonable-effort basis maintains this project. As such, please give us at least 90 days to work on a fix before public exposure.
> [!IMPORTANT]
> For collaborators: if you are interested in helping out with reviewing privting security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
> For collaborators: if you are interested in helping out with reviewing private security disclosures, please see: https://github.com/ggml-org/llama.cpp/discussions/18080
## Requirements

View File

@ -534,7 +534,7 @@ xcodebuild -create-xcframework \
-framework $(pwd)/build-ios-device/framework/llama.framework \
-debug-symbols $(pwd)/build-ios-device/dSYMs/llama.dSYM \
-framework $(pwd)/build-macos/framework/llama.framework \
-debug-symbols $(pwd)/build-macos/dSYMS/llama.dSYM \
-debug-symbols $(pwd)/build-macos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos/framework/llama.framework \
-debug-symbols $(pwd)/build-visionos/dSYMs/llama.dSYM \
-framework $(pwd)/build-visionos-sim/framework/llama.framework \

View File

@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
[](common_params & params, bool value) {
params.kv_unified = value;
}
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH, LLAMA_EXAMPLE_PARALLEL}));
add_opt(common_arg(
{"--context-shift"},
{"--no-context-shift"},
@ -3437,16 +3437,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.speculative.ngram_size_m = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-check-rate"}, "N",
string_format("ngram check rate for ngram-simple/ngram-map speculative decoding (default: %d)", params.speculative.ngram_check_rate),
[](common_params & params, int value) {
if (value < 1) {
throw std::invalid_argument("ngram check rate must be at least 1");
}
params.speculative.ngram_check_rate = value;
}
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--spec-ngram-min-hits"}, "N",
string_format("minimum hits for ngram-map speculative decoding (default: %d)", params.speculative.ngram_min_hits),

View File

@ -380,15 +380,46 @@ std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messa
return msgs;
}
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
static json render_message_to_json(const std::vector<common_chat_msg> & msgs, const jinja::caps & c) {
if (!c.supports_string_content && !c.supports_typed_content) {
LOG_WRN("%s: Neither string content nor typed content is supported by the template. This is unexpected and may lead to issues.\n", __func__);
}
bool only_string_accepted = c.supports_string_content && !c.supports_typed_content;
bool only_typed_accepted = !c.supports_string_content && c.supports_typed_content;
json messages = json::array();
for (const auto & msg : msgs) {
json jmsg = msg.to_json_oaicompat(concat_typed_text);
messages.push_back(jmsg);
if (only_string_accepted) {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ true);
messages.push_back(jmsg);
} else if (only_typed_accepted) {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
if (jmsg.at("content").is_string()) {
jmsg["content"] = json::array({
json{
{"type", "text"},
{"text", jmsg.at("content").get<std::string>()},
}
});
}
messages.push_back(jmsg);
} else {
json jmsg = msg.to_json_oaicompat(/* concat_typed_text= */ false);
messages.push_back(jmsg);
}
}
return messages;
}
// DEPRECATED: only used in tests
json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text) {
jinja::caps c;
c.supports_string_content = true;
c.supports_typed_content = !concat_typed_text;
return render_message_to_json(msgs, c);
}
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & tools) {
std::vector<common_chat_tool> result;
@ -3020,7 +3051,7 @@ static common_chat_params common_chat_templates_apply_jinja(
: *tmpls->template_default;
const auto & src = tmpl.source();
const auto & caps = tmpl.original_caps();
params.messages = common_chat_msgs_to_json_oaicompat(inputs.messages, /* concat_text= */ !tmpl.original_caps().requires_typed_content);
params.messages = render_message_to_json(inputs.messages, tmpl.original_caps());
params.add_generation_prompt = inputs.add_generation_prompt;
params.tool_choice = inputs.tool_choice;
params.reasoning_format = inputs.reasoning_format;

View File

@ -240,6 +240,8 @@ bool common_chat_templates_support_enable_thinking(const common_chat_templates *
// Parses a JSON array of messages in OpenAI's chat completion API format.
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const nlohmann::ordered_json & messages);
// DEPRECATED: only used in tests
nlohmann::ordered_json common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const nlohmann::ordered_json & tools);

View File

@ -1,7 +1,3 @@
#if defined(_MSC_VER)
#define _SILENCE_CXX17_CODECVT_HEADER_DEPRECATION_WARNING
#endif
#include "ggml.h"
#include "gguf.h"
@ -9,12 +5,12 @@
#include "log.h"
#include "llama.h"
#include "sampling.h"
#include "unicode.h"
#include <algorithm>
#include <cinttypes>
#include <climits>
#include <cmath>
#include <codecvt>
#include <chrono>
#include <cstdarg>
#include <cstring>
@ -706,45 +702,28 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
return false;
}
std::u32string filename_utf32;
try {
#if defined(__clang__)
// disable C++17 deprecation warning for std::codecvt_utf8
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wdeprecated-declarations"
#elif defined(__GNUC__)
# pragma GCC diagnostic push
# pragma GCC diagnostic ignored "-Wdeprecated-declarations"
#endif
size_t offset = 0;
while (offset < filename.size()) {
utf8_parse_result result = parse_utf8_codepoint(filename, offset);
std::wstring_convert<std::codecvt_utf8<char32_t>, char32_t> converter;
#if defined(__clang__)
# pragma clang diagnostic pop
#elif defined(__GNUC__)
# pragma GCC diagnostic pop
#endif
filename_utf32 = converter.from_bytes(filename);
// If the reverse conversion mismatches, it means overlong UTF-8 sequences were used,
// or invalid encodings were encountered. Reject such attempts
std::string filename_reencoded = converter.to_bytes(filename_utf32);
if (filename_reencoded != filename) {
if (result.status != utf8_parse_result::SUCCESS) {
return false;
}
} catch (const std::exception &) {
return false;
}
uint32_t c = result.codepoint;
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
for (char32_t c : filename_utf32) {
if ((result.bytes_consumed == 2 && c < 0x80) ||
(result.bytes_consumed == 3 && c < 0x800) ||
(result.bytes_consumed == 4 && c < 0x10000)) {
return false;
}
// Check for forbidden codepoints:
// - Control characters
// - Unicode equivalents of illegal characters
// - UTF-16 surrogate pairs
// - UTF-8 replacement character
// - Byte order mark (BOM)
// - Illegal characters: / \ : * ? " < > |
if (c <= 0x1F // Control characters (C0)
|| c == 0x7F // Control characters (DEL)
|| (c >= 0x80 && c <= 0x9F) // Control characters (C1)
@ -752,6 +731,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
|| c == 0x2215 // Division Slash (forward slash equivalent)
|| c == 0x2216 // Set Minus (backslash equivalent)
|| (c >= 0xD800 && c <= 0xDFFF) // UTF-16 surrogate pairs
|| c > 0x10FFFF // Max Unicode limit
|| c == 0xFFFD // Replacement Character (UTF-8)
|| c == 0xFEFF // Byte Order Mark (BOM)
|| c == ':' || c == '*' // Illegal characters
@ -762,6 +742,7 @@ bool fs_validate_filename(const std::string & filename, bool allow_subdirs) {
// Subdirectories not allowed, reject path separators
return false;
}
offset += result.bytes_consumed;
}
// Reject any leading or trailing ' ', or any trailing '.', these are stripped on Windows and will cause a different filename
@ -1469,66 +1450,6 @@ void common_batch_add(
batch.n_tokens++;
}
//
// Token utils
//
size_t common_lcp(const llama_tokens & a, const llama_tokens & b) {
size_t i;
for (i = 0; i < a.size() && i < b.size() && a[i] == b[i]; i++) {}
return i;
}
size_t common_lcs(const llama_tokens & a, const llama_tokens & b) {
// check for empty sequences
if (a.empty() || b.empty()) {
return 0;
}
// get the lengths of the input sequences
size_t a_len = a.size();
size_t b_len = b.size();
// initialize the maximum length of the longest common subsequence (LCS)
size_t max_length = 0;
// use two rows instead of a 2D matrix to optimize space
std::vector<size_t> prev_row(b_len + 1, 0);
std::vector<size_t> curr_row(b_len + 1, 0);
// iterate through the elements of a
for (size_t i = 1; i <= a_len; i++) {
// iterate through the elements of b
for (size_t j = 1; j <= b_len; j++) {
// if elements at the current positions match
if (a[i - 1] == b[j - 1]) {
// if it's the first element of either sequences, set LCS length to 1
if (i == 1 || j == 1) {
curr_row[j] = 1;
} else {
// increment LCS length by 1 compared to the previous element
curr_row[j] = prev_row[j - 1] + 1;
}
// update max_length if necessary
if (curr_row[j] > max_length) {
max_length = curr_row[j];
}
} else {
// reset LCS length if elements don't match
curr_row[j] = 0;
}
}
// update the previous row for the next iteration
prev_row = curr_row;
}
// return the maximum length of the LCS
return max_length;
}
//
// Vocab utils
//

View File

@ -269,7 +269,6 @@ struct common_params_speculative {
uint16_t ngram_size_n = 12; // ngram size for lookup
uint16_t ngram_size_m = 48; // mgram size for speculative tokens
uint16_t ngram_check_rate = 1; // check rate for ngram lookup
uint16_t ngram_min_hits = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
std::shared_ptr<common_ngram_mod> ngram_mod;
@ -780,16 +779,6 @@ void common_batch_add(
const std::vector<llama_seq_id> & seq_ids,
bool logits);
//
// Token utils
//
// longest common prefix
size_t common_lcp(const llama_tokens & a, const llama_tokens & b);
// longet common subsequence
size_t common_lcs(const llama_tokens & a, const llama_tokens & b);
//
// Vocab utils
//

View File

@ -305,7 +305,10 @@ static bool common_pull_file(httplib::Client & cli,
);
if (!res) {
LOG_ERR("%s: error during download. Status: %d\n", __func__, res ? res->status : -1);
LOG_ERR("%s: download failed: %s (status: %d)\n",
__func__,
httplib::to_string(res.error()).c_str(),
res ? res->status : -1);
return false;
}

View File

@ -63,7 +63,8 @@ static void caps_print_stats(value & v, const std::string & path) {
std::map<std::string, bool> caps::to_map() const {
return {
{"requires_typed_content", requires_typed_content},
{"supports_string_content", supports_string_content},
{"supports_typed_content", supports_typed_content},
{"supports_tools", supports_tools},
{"supports_tool_calls", supports_tool_calls},
{"supports_parallel_tool_calls", supports_parallel_tool_calls},
@ -89,7 +90,7 @@ caps caps_get(jinja::program & prog) {
return v->stats.ops.find(op_name) != v->stats.ops.end();
};
// case: typed content requirement
// case: typed content support
caps_try_execute(
prog,
[&]() {
@ -105,12 +106,16 @@ caps caps_get(jinja::program & prog) {
// tools
return json{nullptr};
},
[&](bool, value & messages, value &) {
[&](bool success, value & messages, value &) {
auto & content = messages->at(0)->at("content");
caps_print_stats(content, "messages[0].content");
if (has_op(content, "selectattr") || has_op(content, "array_access")) {
// accessed as an array
result.requires_typed_content = true;
result.supports_typed_content = true;
}
if (!success) {
// failed to execute with content as string
result.supports_string_content = false;
}
}
);

View File

@ -14,7 +14,9 @@ struct caps {
bool supports_parallel_tool_calls = true;
bool supports_preserve_reasoning = false; // support assistant message with reasoning_content
bool requires_typed_content = false; // default: use string content
// one of the 2 content capabilities must be true
bool supports_string_content = true;
bool supports_typed_content = false;
// for reporting on server
std::map<std::string, bool> to_map() const;

View File

@ -446,6 +446,12 @@ value for_statement::execute_impl(context & ctx) {
value iterable_val = iter_expr->execute(scope);
// mark the variable being iterated as used for stats
if (ctx.is_get_stats) {
iterable_val->stats.used = true;
iterable_val->stats.ops.insert("array_access");
}
if (iterable_val->is_undefined()) {
JJ_DEBUG("%s", "For loop iterable is undefined, skipping loop");
iterable_val = mk_val<value_array>();

View File

@ -231,10 +231,9 @@ void common_ngram_map_draft(common_ngram_map & map,
GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
}
// Only check every check_rate tokens to save compute
// i.e., perform check if (cur_len - idx_last_check) >= check_rate
if (map.idx_last_check + map.check_rate > cur_len) {
return;
if (map.idx_last_check > cur_len) {
// Should not happen because of common_ngram_map_begin().
GGML_ABORT("%s: map.idx_last_check > cur_len: %zu > %zu", __func__, map.idx_last_check, cur_len);
}
map.idx_last_check = cur_len;
@ -462,7 +461,7 @@ void common_ngram_map_draft(common_ngram_map & map,
slot_max = v;
}
}
// What is sum of the other occurences?
// What is sum of the other occurrences?
uint32_t sum_occur = 0;
for (int v = 0; v < COMMON_NGRAM_MAX_VALUES; ++v) {
if (v == slot_max) {

View File

@ -24,7 +24,6 @@
struct common_ngram_simple_config {
uint16_t size_ngram; // size of n-grams to lookup in self-mode
uint16_t size_mgram; // size of m-grams to draft in self-mode
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
};
// Searches for a n-gram in the history and checks whether a draft sequence should be generated.
@ -45,7 +44,7 @@ llama_tokens common_ngram_simple_draft(
// statistics of a m-gram after a known n-gram
struct common_ngram_map_value {
size_t value_idx = 0; // index of value m-gram in token-history (0 if unused)
uint16_t value_num = 0; // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
uint16_t value_num = 0; // number of occurrences of this value m-gram after the key n-gram (0 in an unused values-slot)
int16_t n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
};
@ -54,7 +53,7 @@ struct common_ngram_map_key {
size_t key_idx; // index of key n-gram in token-history
size_t stat_idx; // index of last token of stastistics computation (key_num, values)
uint16_t key_num; // number of occurences of this key n-gram in token-history
uint16_t key_num; // number of occurrences of this key n-gram in token-history
common_ngram_map_value values[COMMON_NGRAM_MAX_VALUES]; // some known values after the key
};
@ -66,15 +65,14 @@ struct common_ngram_map {
bool key_only; // true if only key n-grams are used, no values.
std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
uint16_t min_hits; // minimum number of key hits to consider a draft
bool show_key_map_stats = false; // true, if statitics of the key_map should be printed.
bool show_key_map_stats = false; // true, if statistics of the key_map should be printed.
common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
uint16_t check_rate, uint16_t min_hits)
uint16_t min_hits)
: size_key(sz_key), size_value(sz_value), key_only(only_keys),
check_rate(check_rate), min_hits(min_hits) {
min_hits(min_hits) {
key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used
}

View File

@ -113,13 +113,14 @@ static bool common_speculative_are_compatible(
struct common_speculative_state {
const enum common_speculative_type type;
// TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens
// TODO: add n_call_begin, n_call_accept
size_t drafts_call_count = 0; // number of times this implementation was called.
size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation.
size_t drafts_accepted_count = 0; // number of times a draft or part was accepted by the target model.
size_t drafts_generated_tokens = 0; // number of tokens generated by this implementation.
size_t drafts_accepted_tokens = 0; // number of tokens accepted by the target model.
size_t n_call_begin = 0; // number of times this implementation was called for refresh.
size_t n_call_draft = 0; // number of times this implementation was called for generation.
size_t n_call_accept = 0; // number of times this implementation was called for accumulation.
size_t n_gen_drafts = 0; // number of times a draft or part was generated by this implementation.
size_t n_acc_drafts = 0; // number of times a draft or part was accepted by the target model.
size_t n_gen_tokens = 0; // number of tokens generated by this implementation.
size_t n_acc_tokens = 0; // number of tokens accepted by the target model.
// TODO: track performance of most recent calls
const bool gen_perf = true; // whether to generate performance stats.
@ -465,8 +466,6 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_config config;
uint16_t check_id = 0; // used to control the frequency of generating drafts
common_speculative_state_ngram_simple(
enum common_speculative_type type,
common_ngram_simple_config config)
@ -481,11 +480,6 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
++check_id;
if (check_id < config.check_rate) {
return;
}
check_id = 0;
result = common_ngram_simple_draft(config, prompt_tgt, id_last);
GGML_UNUSED(params);
@ -752,10 +746,9 @@ static common_ngram_map get_common_ngram_map(const common_speculative_config & c
uint16_t size_key = config.params.ngram_size_n;
uint16_t size_value = config.params.ngram_size_m;
bool key_only = (config.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
uint16_t check_rate = config.params.ngram_check_rate;
uint16_t min_hits = config.params.ngram_min_hits;
return common_ngram_map(size_key, size_value, key_only, check_rate, min_hits);
return common_ngram_map(size_key, size_value, key_only, min_hits);
}
static common_speculative_state_ngram_cache create_state_ngram_cache(
@ -805,6 +798,42 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
return it->second;
}
bool common_speculative_is_compat(llama_context * ctx_tgt) {
auto * mem = llama_get_memory(ctx_tgt);
if (mem == nullptr) {
return false;
}
bool res = true;
llama_memory_clear(mem, true);
// eval 2 tokens to check if the context is compatible
std::vector<llama_token> tmp;
tmp.push_back(0);
tmp.push_back(0);
int ret = llama_decode(ctx_tgt, llama_batch_get_one(tmp.data(), tmp.size()));
if (ret != 0) {
LOG_ERR("%s: llama_decode() failed: %d\n", __func__, ret);
res = false;
goto done;
}
// try to remove the last tokens
if (!llama_memory_seq_rm(mem, 0, 1, -1)) {
LOG_WRN("%s: the target context does not support partial sequence removal\n", __func__);
res = false;
goto done;
}
done:
llama_memory_clear(mem, true);
llama_synchronize(ctx_tgt);
return res;
}
// initialization of the speculative decoding system
//
common_speculative * common_speculative_init(
@ -895,12 +924,10 @@ common_speculative * common_speculative_init(
uint16_t ngram_size_key = ngram_map.size_key;
uint16_t mgram_size_value = ngram_map.size_value;
uint16_t check_rate = ngram_map.check_rate;
auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value,
/* .check_rate = */ check_rate
/* .size_mgram = */ mgram_size_value
};
auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type,
@ -961,6 +988,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
for (auto & impl : spec->impls) {
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
impl->begin(prompt);
impl->n_call_begin++;
}
}
@ -977,17 +1005,17 @@ llama_tokens common_speculative_draft(
{
common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
impl->draft(params, prompt_tgt, id_last, result);
impl->drafts_call_count++;
impl->n_call_draft++;
}
if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
impl.get()->drafts_call_count, result.size());
impl.get()->n_call_draft, result.size());
spec->curr_impl = impl.get(); // set current implementation for stats
impl->drafts_generated_count++;
impl->drafts_generated_tokens += result.size();
impl->n_gen_drafts++;
impl->n_gen_tokens += result.size();
break; // We have a draft, so break out of the loop and return it.
}
@ -1008,11 +1036,12 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
{
common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
if (n_accepted > 0) {
impl->drafts_accepted_count++;
impl->drafts_accepted_tokens += n_accepted;
impl->n_acc_drafts++;
impl->n_acc_tokens += n_accepted;
}
impl->accept(n_accepted);
impl->n_call_accept++;
}
}
@ -1033,13 +1062,13 @@ void common_speculative_print_stats(const common_speculative * spec) {
str_perf = "";
}
LOG_INF("statistics %s: #calls = %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
LOG_INF("statistics %s: #calls(b,g,a) = %zu %zu %zu, #gen drafts = %zu, #acc drafts = %zu, #gen tokens = %zu, #acc tokens = %zu%s\n",
common_speculative_type_to_str(impl->type).c_str(),
impl->drafts_call_count,
impl->drafts_generated_count,
impl->drafts_accepted_count,
impl->drafts_generated_tokens,
impl->drafts_accepted_tokens,
impl->n_call_begin, impl->n_call_draft, impl->n_call_accept,
impl->n_gen_drafts,
impl->n_acc_drafts,
impl->n_gen_tokens,
impl->n_acc_tokens,
str_perf.c_str());
}
}

View File

@ -14,6 +14,10 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// convert type to string
std::string common_speculative_type_to_str(enum common_speculative_type type);
// check if the llama_context is compatible for speculative decoding
// note: clears the memory of the context
bool common_speculative_is_compat(llama_context * ctx_tgt);
common_speculative * common_speculative_init(
common_params_speculative & params,
llama_context * ctx_tgt);

View File

@ -160,8 +160,6 @@ class ModelBase:
self.ftype = gguf.LlamaFileType.MOSTLY_F16
logger.info("heuristics unable to detect tensor dtype, defaulting to --outtype f16")
self.dequant_model()
# Configure GGUF Writer
self.gguf_writer = gguf.GGUFWriter(path=None, arch=gguf.MODEL_ARCH_NAMES[self.model_arch], endianess=self.endianess, use_temp_file=self.use_temp_file,
split_max_tensors=split_max_tensors, split_max_size=split_max_size, dry_run=dry_run, small_first_shard=small_first_shard)
@ -527,6 +525,8 @@ class ModelBase:
return ()
def prepare_tensors(self):
self.dequant_model()
# Handle empty tensor_map for models with block_count=0 (like MobileNetV5)
if self.tensor_map.mapping:
max_name_len = max(len(s) for _, s in self.tensor_map.mapping.values()) + len(".weight,")
@ -920,7 +920,7 @@ class TextModel(ModelBase):
self.gguf_writer.add_expert_group_used_count(n_group_used)
logger.info(f"gguf: expert groups used count = {n_group_used}")
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation_func"], optional=True)) is not None:
if (score_func := self.find_hparam(["score_function", "scoring_func", "score_func", "moe_router_activation", "moe_router_activation_func"], optional=True)) is not None:
if score_func == "sigmoid":
self.gguf_writer.add_expert_gating_func(gguf.ExpertGatingFuncType.SIGMOID)
elif score_func == "softmax":
@ -1261,6 +1261,9 @@ class TextModel(ModelBase):
if chkhsh == "6c81ce329e0802883b22eabab0d3fa48357337ef1ecb45443828bf1f6254833f":
# ref: https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B
res = "exaone-moe"
if chkhsh == "d30d75d9059f1aa2c19359de71047b3ae408c70875e8a3ccf8c5fba56c9d8af4":
# ref: https://huggingface.co/Qwen/Qwen3.5-9B-Instruct
res = "qwen35"
if res is None:
logger.warning("\n")
@ -1812,7 +1815,7 @@ class MmprojModel(ModelBase):
preprocessor_config: dict[str, Any]
global_config: dict[str, Any]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers"]
n_block_keys = ["n_layers", "num_hidden_layers", "n_layer", "num_layers", "depth", "encoder_layers", "vt_num_hidden_layers"]
has_vision_encoder: bool = True # by default
has_audio_encoder: bool = False
@ -1867,7 +1870,15 @@ class MmprojModel(ModelBase):
preprocessor_config_path = self.dir_model / "preprocessor_config.json"
if preprocessor_config_path.is_file():
with open(preprocessor_config_path, "r", encoding="utf-8") as f:
self.preprocessor_config = json.load(f)
cfg = json.load(f)
# move media_proc_cfg to root level for compat
if "media_proc_cfg" in cfg:
cfg = {
**cfg,
**cfg["media_proc_cfg"],
}
# merge configs
self.preprocessor_config = {**self.preprocessor_config, **cfg}
# prefer processor_config.json if possible
processor_config_path = self.dir_model / "processor_config.json"
@ -1916,10 +1927,10 @@ class MmprojModel(ModelBase):
self.image_size = self.find_vparam(["image_size"])
self.gguf_writer.add_vision_image_size(self.image_size)
self.gguf_writer.add_vision_patch_size(self.find_vparam(["patch_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size"]))
self.gguf_writer.add_vision_embedding_length(self.find_vparam(["hidden_size", "vt_hidden_size"]))
self.gguf_writer.add_vision_feed_forward_length(self.find_vparam(["intermediate_size", "vt_intermediate_size"]))
self.gguf_writer.add_vision_block_count(self.find_vparam(self.n_block_keys))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads"]))
self.gguf_writer.add_vision_head_count(self.find_vparam(["num_attention_heads", "num_heads", "vt_num_attention_heads"]))
# preprocessor config
image_mean = _MISTRAL_COMMON_DATASET_MEAN if self.is_mistral_format else self.preprocessor_config["image_mean"]
@ -4109,37 +4120,29 @@ class Qwen2MoeModel(TextModel):
# Expected GGML ne: {n_embd, n_ff_exp, n_expert} for gate/up, {n_ff_exp, n_embd, n_expert} for down
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
mapped = f"{name}.weight" if not name.endswith(".weight") else name
# Input: (n_expert=128, n_ff_exp=768, n_embd=2048)
# Want GGML ne: {n_ff_exp, n_embd, n_expert} = {768, 2048, 128}
# Need PyTorch: (128, 2048, 768) [reversed of GGML]
# So: permute(0, 2, 1): (128, 768, 2048) -> (128, 2048, 768)
permuted = data_torch.permute(0, 2, 1).contiguous()
yield from super().modify_tensors(permuted, mapped, bid)
# HF: [n_expert, n_embd, n_ff] -> GGML: {n_ff, n_embd, n_expert}
yield from super().modify_tensors(data_torch, mapped, bid)
return
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
if data_torch.ndim < 3 or data_torch.shape[-2] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
yield from super().modify_tensors(perm_gate, mapped_gate, bid)
yield from super().modify_tensors(perm_up, mapped_up, bid)
# HF: [n_expert, 2*n_ff, n_embd] -> split on dim=-2
n_ff = data_torch.shape[-2] // 2
gate = data_torch[..., :n_ff, :].contiguous()
up = data_torch[..., n_ff:, :].contiguous()
# gate/up: [n_expert, n_ff, n_embd] -> GGML: {n_embd, n_ff, n_expert}
base_name = name.removesuffix(".weight").removesuffix(".gate_up_proj")
mapped_gate = f"{base_name}.gate_proj.weight"
mapped_up = f"{base_name}.up_proj.weight"
yield from super().modify_tensors(gate, mapped_gate, bid)
yield from super().modify_tensors(up, mapped_up, bid)
return
if name.startswith("mlp") or name.startswith("vision_model") or name.startswith("model.vision_tower") or name.startswith("model.multi_modal_projector") or name.startswith("model.visual"):
# skip visual tensors
return
if name.find("experts") != -1:
n_experts = self.hparams["num_experts"]
assert bid is not None
@ -4295,6 +4298,7 @@ class Qwen3NextModel(Qwen2MoeModel):
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
self.gguf_writer.add_full_attention_interval(self.hparams.get("full_attention_interval", 4))
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
@ -4359,7 +4363,7 @@ class RND1Model(Qwen2MoeModel):
self.gguf_writer.add_mask_token_id(mask_token_id)
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration")
@ModelBase.register("Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration", "Qwen3_5ForConditionalGeneration", "Qwen3_5MoeForConditionalGeneration")
class Qwen3VLVisionModel(MmprojModel):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@ -4405,6 +4409,10 @@ class Qwen3VLVisionModel(MmprojModel):
if name.startswith("model.language_model.") or name.startswith("lm_head."):
return
# Skip MTP tensors
if name.startswith("mtp."):
return
if name.startswith("model.visual."):
name = name.replace("model.visual.", "visual.", 1)
@ -4535,9 +4543,125 @@ class Qwen3VLMoeTextModel(Qwen3MoeModel):
if name.startswith("model.visual."):
return
# Qwen3VL has transposed packed tensors, so we treat it differently from general Qwen2MoE packed tensors
if name.endswith("mlp.experts.down_proj") or name.endswith("mlp.experts.down_proj.weight"):
name = name.replace("language_model.", "")
mapped = f"{name}.weight" if not name.endswith(".weight") else name
permuted = data_torch.permute(0, 2, 1).contiguous()
yield from ModelBase.modify_tensors(self, permuted, mapped, bid)
return
if name.endswith("mlp.experts.gate_up_proj") or name.endswith("mlp.experts.gate_up_proj.weight"):
name = name.replace("language_model.", "")
if data_torch.ndim < 3 or data_torch.shape[-1] % 2 != 0:
raise ValueError(f"Unexpected gate_up_proj shape for {name}: {tuple(data_torch.shape)}")
split_dim = data_torch.shape[-1] // 2
gate = data_torch[..., :split_dim].contiguous()
up = data_torch[..., split_dim:].contiguous()
# Input gate/up: (n_expert=128, n_embd=2048, n_ff_exp=768)
# Want GGML ne: {n_embd, n_ff_exp, n_expert} = {2048, 768, 128}
# Need PyTorch: (128, 768, 2048) [reversed of GGML]
# So: permute(0, 2, 1): (128, 2048, 768) -> (128, 768, 2048)
base_name = name.removesuffix(".weight")
base = base_name.rsplit('.', 1)[0]
mapped_gate = f"{base}.gate_proj.weight"
mapped_up = f"{base}.up_proj.weight"
perm_gate = gate.permute(0, 2, 1).contiguous()
perm_up = up.permute(0, 2, 1).contiguous()
yield from ModelBase.modify_tensors(self, perm_gate, mapped_gate, bid)
yield from ModelBase.modify_tensors(self, perm_up, mapped_up, bid)
return
yield from super().modify_tensors(data_torch, name, bid)
class _LinearAttentionVReorderBase(Qwen3NextModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT # overridden by subclasses
"""reorders V heads from grouped to tiled order for ggml broadcast
see https://github.com/ggml-org/llama.cpp/pull/19468#discussion_r2786394306
Linear attention may has num_k_heads < num_v_heads. The HF weights store
V heads grouped by K head: [G0_v0..v{r-1}, G1_v0..v{r-1}, ...].
ggml binary ops use tiled broadcast: [K0, K1, ..., K0, K1, ...].
We reorder V heads to tiled order so ggml_repeat can replace the expensive
interleaved repeat: [G0_v0, G1_v0, ..., G0_v1, G1_v1, ...].
"""
@staticmethod
def _reorder_v_heads(tensor: Tensor, dim: int, num_k_heads: int, num_v_per_k: int, head_dim: int) -> Tensor:
"""Reorder V heads from grouped (by K head) to tiled order along the given dimension."""
shape = list(tensor.shape)
if dim < 0:
dim += len(shape)
new_shape = shape[:dim] + [num_k_heads, num_v_per_k, head_dim] + shape[dim + 1:]
tensor = tensor.reshape(*new_shape)
perm = list(range(len(new_shape)))
perm[dim], perm[dim + 1] = perm[dim + 1], perm[dim]
return tensor.permute(*perm).contiguous().reshape(*shape)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
num_k_heads = self.hparams.get("linear_num_key_heads", 0)
num_v_heads = self.hparams.get("linear_num_value_heads", 0)
if num_k_heads > 0 and num_v_heads > 0 and num_k_heads != num_v_heads and "linear_attn." in name:
head_k_dim = self.hparams["linear_key_head_dim"]
head_v_dim = self.hparams["linear_value_head_dim"]
num_v_per_k = num_v_heads // num_k_heads
if ".in_proj_qkv." in name:
# QKV weight: reorder only the V rows
q_dim = head_k_dim * num_k_heads
k_dim = head_k_dim * num_k_heads
q = data_torch[:q_dim]
k = data_torch[q_dim:q_dim + k_dim]
v = data_torch[q_dim + k_dim:]
v = self._reorder_v_heads(v, 0, num_k_heads, num_v_per_k, head_v_dim)
data_torch = torch.cat([q, k, v], dim=0)
elif ".in_proj_z." in name:
# Z gate weight: reorder rows (num_v_heads * head_v_dim)
data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, head_v_dim)
elif ".in_proj_b." in name or ".in_proj_a." in name:
# Beta/Alpha weight: reorder rows (num_v_heads, head_dim=1)
data_torch = self._reorder_v_heads(data_torch, 0, num_k_heads, num_v_per_k, 1)
elif ".A_log" in name or ".dt_bias" in name or ".dt_proj" in name:
# A_log / dt_bias: 1D parameters with num_v_heads elements
if data_torch.ndim == 1:
data_torch = self._reorder_v_heads(
data_torch.unsqueeze(-1), 0, num_k_heads, num_v_per_k, 1
).squeeze(-1)
else:
data_torch = self._reorder_v_heads(data_torch, -1, num_k_heads, num_v_per_k, 1)
elif ".conv1d" in name:
# Conv1d kernel: reorder only the V channel portion
data = data_torch.squeeze()
qk_channels = head_k_dim * num_k_heads * 2
qk_part = data[:qk_channels]
v_part = data[qk_channels:]
v_part = self._reorder_v_heads(v_part, 0, num_k_heads, num_v_per_k, head_v_dim)
data_torch = torch.cat([qk_part, v_part], dim=0)
elif ".out_proj." in name:
# Out projection weight: reorder columns (input dimension)
data_torch = self._reorder_v_heads(data_torch, 1, num_k_heads, num_v_per_k, head_v_dim)
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("Qwen3_5ForConditionalGeneration")
class Qwen3_5TextModel(_LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35
@ModelBase.register("Qwen3_5MoeForConditionalGeneration")
class Qwen3_5MoeTextModel(_LinearAttentionVReorderBase):
model_arch = gguf.MODEL_ARCH.QWEN35MOE
@ModelBase.register("GPT2LMHeadModel")
class GPT2Model(TextModel):
model_arch = gguf.MODEL_ARCH.GPT2
@ -7579,6 +7703,7 @@ class DeepseekModel(TextModel):
"DeepseekV2ForCausalLM",
"DeepseekV3ForCausalLM",
"KimiVLForConditionalGeneration",
"KimiK25ForConditionalGeneration",
"YoutuForCausalLM",
"YoutuVLForConditionalGeneration",
)
@ -7697,8 +7822,8 @@ class DeepseekV2Model(TextModel):
_experts: list[dict[str, Tensor]] | None = None
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# skip vision tensors and remove "language_model." for Kimi-VL
if "vision_tower" in name or "multi_modal_projector" in name:
# skip vision tensors and remove "language_model." for Kimi-VL and Kimi-K2.5
if "vision_tower" in name or "multi_modal_projector" in name or "mm_projector" in name:
return
if name.startswith("siglip2.") or name.startswith("merger."):
return
@ -7912,6 +8037,135 @@ class MimoV2Model(TextModel):
raise ValueError(f"Unprocessed experts: {experts}")
@ModelBase.register("Step3p5ForCausalLM")
class Step35Model(TextModel):
model_arch = gguf.MODEL_ARCH.STEP35
def set_gguf_parameters(self):
rope_theta = self.hparams.get("rope_theta")
if isinstance(rope_theta, list):
self.hparams["rope_theta"] = float(rope_theta[0])
self.hparams["local_rope_theta"] = float(rope_theta[1])
self.rope_parameters["rope_theta"] = self.hparams["rope_theta"]
self.rope_parameters["sliding_attention"] = {"rope_theta": self.hparams["local_rope_theta"]}
super().set_gguf_parameters()
layer_types = self.hparams.get("layer_types") or []
partial_rotary_factors = self.hparams.get("partial_rotary_factors") or []
attn_other = self.hparams.get("attention_other_setting") or {}
n_head_base = self.hparams["num_attention_heads"]
n_kv_base = self.hparams["num_attention_groups"]
n_head_swa = attn_other.get("num_attention_heads", n_head_base)
n_kv_swa = attn_other.get("num_attention_groups", n_kv_base)
layer_types = layer_types[: self.block_count]
partial_rotary_factors = partial_rotary_factors[: self.block_count]
assert [1.0 if lt == "sliding_attention" else 0.5 for lt in layer_types] == partial_rotary_factors
head_arr = [n_head_swa if lt == "sliding_attention" else n_head_base for lt in layer_types]
kv_arr = [n_kv_swa if lt == "sliding_attention" else n_kv_base for lt in layer_types]
swa_pat = [lt == "sliding_attention" for lt in layer_types]
self.gguf_writer.add_head_count(head_arr)
self.gguf_writer.add_head_count_kv(kv_arr)
self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
self.gguf_writer.add_sliding_window_pattern(swa_pat)
self.gguf_writer.add_value_length(self.hparams["head_dim"])
# MoE params
self.gguf_writer.add_expert_count(self.hparams["moe_num_experts"])
self.gguf_writer.add_expert_used_count(self.hparams["moe_top_k"])
self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
self.gguf_writer.add_expert_shared_feed_forward_length(self.hparams["share_expert_dim"])
if (moe_router_scaling_factor := self.hparams.get("moe_router_scaling_factor")) is not None:
self.gguf_writer.add_expert_weights_scale(moe_router_scaling_factor)
if (norm_expert_weight := self.hparams.get("norm_expert_weight")) is not None:
self.gguf_writer.add_expert_weights_norm(norm_expert_weight)
# leading dense blocks
leading_dense = 0
moe_layers_enum = self.hparams.get("moe_layers_enum")
if isinstance(moe_layers_enum, str) and moe_layers_enum.strip():
moe_layers = sorted(int(i) for i in moe_layers_enum.strip().split(","))
if moe_layers:
leading_dense = max(0, moe_layers[0])
self.gguf_writer.add_leading_dense_block_count(leading_dense)
self.gguf_writer.add_moe_every_n_layers(int(self.hparams.get("moe_every_n_layer", 1)))
self.gguf_writer.add_layer_norm_rms_eps(self.hparams.get("rms_norm_eps", 1e-5))
# Optional per-layer SwiGLU clamps.
if (limits := self.hparams.get("swiglu_limits")) is not None:
limits_f = [0.0 if v is None else float(v) for v in limits[: self.block_count]]
self.gguf_writer.add_swiglu_clamp_exp(limits_f)
if (limits_shared := self.hparams.get("swiglu_limits_shared")) is not None:
limits_shared_f = [0.0 if v is None else float(v) for v in limits_shared[: self.block_count]]
self.gguf_writer.add_swiglu_clamp_shexp(limits_shared_f)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
# remove mtp layers
if (m := re.match(r"model\.layers\.(\d+)\.", name)) is not None:
il = int(m.group(1))
n_main = int(self.hparams.get("num_hidden_layers", self.block_count))
if il >= n_main:
return
if name.endswith("norm.weight"):
data_torch += 1.0
# Map router bias (expert selection bias) to a GGUF bias tensor
if name.endswith(".moe.router_bias"):
name += ".bias"
if name.endswith((".self_attn.g_proj.weight", ".moe.gate.weight", ".moe.up_proj.weight", ".moe.gate_proj.weight", ".moe.down_proj.weight")):
data_torch = data_torch.squeeze().contiguous()
yield from super().modify_tensors(data_torch, name, bid)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Step35 can optionally use Llama-3 style RoPE scaling (HF: rope_scaling.rope_type == "llama3").
# llama.cpp represents this via a single extra tensor: "rope_freqs.weight" (aka MODEL_TENSOR.ROPE_FREQS).
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
rope_type = rope_params.get("rope_type") or ""
if rope_type.lower() != "llama3":
return
# Step35 configs can carry per-layer rope_theta as a list; for llama3 rope factors we use the base value.
rope_theta = self.hparams.get("rope_theta", 10000.0)
if isinstance(rope_theta, list):
rope_theta = rope_theta[0]
base = float(rope_theta)
if (dim := self.hparams.get("head_dim")) is None:
dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
dim = int(dim)
freqs = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
factor = float(rope_params.get("factor", 8.0))
low_freq_factor = float(rope_params.get("low_freq_factor", 1.0))
high_freq_factor = float(rope_params.get("high_freq_factor", 4.0))
old_context_len = int(rope_params.get("original_max_position_embeddings", self.hparams.get("original_max_position_embeddings", 8192)))
low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
rope_factors: list[float] = []
for freq in freqs:
wavelen = 2 * math.pi / float(freq)
if wavelen < high_freq_wavelen:
rope_factors.append(1.0)
elif wavelen > low_freq_wavelen:
rope_factors.append(factor)
else:
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
rope_factors.append(1.0 / ((1.0 - smooth) / factor + smooth))
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FREQS), torch.tensor(rope_factors, dtype=torch.float32))
@ModelBase.register("PanguEmbeddedForCausalLM")
class PanguEmbeddedModel(TextModel):
model_arch = gguf.MODEL_ARCH.PANGU_EMBED
@ -10931,6 +11185,103 @@ class KimiVLModel(MmprojModel):
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("KimiK25ForConditionalGeneration")
class KimiK25Model(MmprojModel):
"""Kimi-K2.5 with MoonViT3d vision encoder"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
assert self.hparams_vision is not None, "Kimi-K2.5 requires vision_config in model config"
self.merge_kernel_size = tuple(self.hparams_vision.get("merge_kernel_size", [2, 2]))
self.patch_size = self.hparams_vision.get("patch_size", 14)
# Set image_size for compatibility with base class
# Use position embedding dimensions as image_size reference
pos_emb_h = self.hparams_vision.get("init_pos_emb_height", 64)
self.hparams_vision["image_size"] = pos_emb_h * self.patch_size
def set_gguf_parameters(self):
# Base class MmprojModel.set_gguf_parameters() already writes:
# - vision_block_count, vision_head_count, vision_embedding_length
# - vision_feed_forward_length, vision_patch_size, image_mean, image_std
# via find_vparam() which handles the vt_* prefixed keys in Kimi-K2.5's config
super().set_gguf_parameters()
assert self.hparams_vision is not None
self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.KIMIK25)
# Position embedding parameters (for interpolation)
self.gguf_writer.add_uint32("vision.pos_emb_height", self.hparams_vision.get("init_pos_emb_height", 64))
self.gguf_writer.add_uint32("vision.pos_emb_width", self.hparams_vision.get("init_pos_emb_width", 64))
self.gguf_writer.add_uint32("vision.pos_emb_time", self.hparams_vision.get("init_pos_emb_time", 4))
# Projector parameters
self.gguf_writer.add_vision_use_gelu(self.hparams_vision.get("projector_hidden_act", "gelu") == "gelu")
self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams_vision.get("projector_ln_eps", 1e-5))
self.gguf_writer.add_vision_projector_scale_factor(self.merge_kernel_size[0])
# Image size limits
# Note: in_patch_limit is for images, in_patch_limit_each_frame is for video (not supported yet)
in_patch_limit = self.preprocessor_config.get("in_patch_limit", 16384)
min_patches = 8 # reasonable minimum
pixels_per_patch = self.patch_size ** 2
self.gguf_writer.add_vision_min_pixels(min_patches * pixels_per_patch)
self.gguf_writer.add_vision_max_pixels(in_patch_limit * pixels_per_patch)
@staticmethod
def permute(weights: Tensor, n_head: int) -> Tensor:
out_dim, in_dim = weights.shape
head_dim = out_dim // n_head
w = weights.reshape(n_head, head_dim // 4, 2, 2, in_dim)
w = w.permute(0, 2, 1, 3, 4)
return w.reshape(out_dim, in_dim)
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
# Only process vision and projector tensors
is_vision = any(x in name for x in ["vision_tower", "mm_projector"])
if not is_vision:
return
assert self.hparams_vision is not None
n_head = self.hparams_vision.get("num_attention_heads", 16)
# Permute Q/K weights/biases from interleaved to split RoPE format
# This allows using build_rope_2d at runtime without post-permutation.
if "wqkv" in name:
out_dim = data_torch.shape[0]
qkv_dim = out_dim // 3
head_dim = qkv_dim // n_head
if "weight" in name:
wq, wk, wv = data_torch[:qkv_dim, :], data_torch[qkv_dim:2 * qkv_dim, :], data_torch[2 * qkv_dim:, :]
wq = self.permute(wq, n_head)
wk = self.permute(wk, n_head)
data_torch = torch.cat([wq, wk, wv], dim=0)
elif "bias" in name:
bq, bk, bv = data_torch[:qkv_dim], data_torch[qkv_dim:2 * qkv_dim], data_torch[2 * qkv_dim:]
bq = bq.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
bk = bk.reshape(n_head, head_dim // 4, 2, 2).permute(0, 2, 1, 3).reshape(-1)
data_torch = torch.cat([bq, bk, bv], dim=0)
# Temporal embeddings: (T, 1, C) → (T, C)
if "pos_emb.time_weight" in name:
T, _, C = data_torch.shape
data_torch = data_torch.reshape(T, C)
# PatchMergerMLP tensor name mapping
# proj.0.weight → proj.linear_1.weight
# proj.2.weight → proj.linear_2.weight
if "mm_projector.proj.0." in name:
name = name.replace(".proj.0.", ".proj.linear_1.")
elif "mm_projector.proj.2." in name:
name = name.replace(".proj.2.", ".proj.linear_2.")
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("CogVLMForCausalLM")
class CogVLMVisionModel(MmprojModel):

View File

@ -148,6 +148,7 @@ models = [
{"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
{"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
{"name": "exaone-moe", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/LGAI-EXAONE/K-EXAONE-236B-A23B", },
{"name": "qwen35", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3.5-9B-Instruct", }
]
# some models are known to be broken upstream, so we will skip them as exceptions

180
docs/backend/VirtGPU.md Normal file
View File

@ -0,0 +1,180 @@
# GGML-VirtGPU Backend
The GGML-VirtGPU backend enables GGML applications to run machine
learning computations on host hardware while the application itself
runs inside a virtual machine. It uses host-guest shared memory to
efficiently share data buffers between the two sides.
This backend relies on the virtio-gpu, and VirglRenderer API Remoting
(APIR) component. The backend is split into two libraries:
- a GGML implementation (the "remoting frontend"), running in the
guest and interacting with the virtgpu device
- a VirglRenderer APIR compatible library (the "remoting backend"),
running in the host and interacting with Virglrenderer and an actual
GGML device backend.
## OS support
| OS | Status | Backend | CI testing | Notes
| -------- | ----------------- | ----------- | ----------- | -----
| MacOS 14 | Supported | ggml-metal | X | Working when compiled on MacOS 14
| MacOS 15 | Supported | ggml-metal | X | Working when compiled on MacOS 14 or MacOS 15
| MacOS 26 | Not tested | | |
| Linux | Under development | ggml-vulkan | not working | Working locally, CI running into deadlocks
## Architecture Overview
The GGML-VirtGPU backend consists of three main components:
```mermaid
graph TD
%% Nodes
subgraph GuestVM ["Guest VM - Frontend"]
App([GGML Application<br/>llama.cpp, etc.])
direction TB
Interface[GGML Backend Interface]
Comm["GGML-VirtGPU<br/>(hypercalls + shared mem)"]
App --> Interface
Interface --> Comm
end
API[virtio-gpu / virglrenderer API]
subgraph HostSystem [Host System - Backend]
direction TB
Dispatcher[GGML-VirtGPU-Backend]
BackendLib[GGML Backend library<br/>Metal / Vulkan / CPU / ...]
Dispatcher --> BackendLib
end
%% Connections
Comm --> API
API --> HostSystem
```
### Key Components
1. **Guest-side Frontend** (`ggml-virtgpu/`): Implements the GGML backend interface and forwards operations to the host
2. **Host-side Backend** (`ggml-virtgpu/backend/`): Receives forwarded operations and executes them on actual hardware backends
3. **Communication Layer**: Uses virtio-gpu hypercalls and shared memory for efficient data transfer
## Features
- **Dynamic backend loading** on the host side (CPU, CUDA, Metal, etc.)
- **Zero-copy data transfer** via host-guest shared memory pages
## Communication Protocol
### Hypercalls and Shared Memory
The backend uses two primary communication mechanisms:
1. **Hypercalls (`DRM_IOCTL_VIRTGPU_EXECBUFFER`)**: Trigger remote execution from guest to host
2. **Shared Memory Pages**: Zero-copy data transfer for tensors and parameters
#### Shared Memory Layout
Each connection uses two shared memory buffers:
- **Data Buffer** (24 MiB): For command/response data and tensor transfers
- **Reply Buffer** (16 KiB): For command replies and status information
- **Data Buffers**: Dynamically allocated host-guest shared buffers
served as GGML buffers.
### APIR Protocol
The Virglrender API Remoting protocol defines three command types:
- `HANDSHAKE`: Protocol version negotiation and capability discovery
- `LOADLIBRARY`: Dynamic loading of backend libraries on the host
- `FORWARD`: API function call forwarding
### Binary Serialization
Commands and data are serialized using a custom binary protocol with:
- Fixed-size encoding for basic types
- Variable-length arrays with size prefixes
- Buffer bounds checking
- Error recovery mechanisms
## Supported Operations
### Device Operations
- Device enumeration and capability queries
- Memory information (total/free)
- Backend type detection
### Buffer Operations
- Buffer allocation and deallocation
- Tensor data transfer (host ↔ guest)
- Memory copying and clearing
### Computation Operations
- Graph execution forwarding
## Build Requirements
### Guest-side Dependencies
- `libdrm` for DRM/virtio-gpu communication
- C++20 compatible compiler
- CMake 3.14+
### Host-side Dependencies
- virglrenderer with APIR support (pending upstream review)
- Target backend libraries (libggml-metal, libggml-vulkan, etc.)
## Configuration
### Environment Variables
- `GGML_VIRTGPU_BACKEND_LIBRARY`: Path to the host-side backend library
- `GGML_VIRTGPU_DEBUG`: Enable debug logging
### Build Options
- `GGML_VIRTGPU`: Enable the VirtGPU backend (`ON` or `OFF`, default: `OFF`)
- `GGML_VIRTGPU_BACKEND`: Build the host-side backend component (`ON`, `OFF` or `ONLY`, default: `OFF`)
### System Requirements
- VM with virtio-gpu support
- VirglRenderer with APIR patches
- Compatible backend libraries on host
## Limitations
- **VM-specific**: Only works in virtual machines with virtio-gpu support
- **Host dependency**: Requires properly configured host-side backend
- **Latency**: Small overhead from VM escaping for each operation
* This work is pending upstream changes in the VirglRenderer
project.
* The backend can be tested with Virglrenderer compiled from source
using this PR:
https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590
* This work is pending changes in the VMM/hypervisor running the
virtual machine, which need to know how to route the newly
introduced APIR capset.
* The environment variable `VIRGL_ROUTE_VENUS_TO_APIR=1` allows
using the Venus capset, until the relevant hypervisors have been
patched. However, setting this flag breaks the Vulkan/Venus normal
behavior.
* The environment variable `GGML_REMOTING_USE_APIR_CAPSET` tells the
`ggml-virtgpu` backend to use the APIR capset. This will become
the default when the relevant hypervisors have been patched.
* This work focused on improving the performance of llama.cpp running
on MacOS containers, and is mainly tested on this platform. The
linux support (via `krun`) is in progress.
## See Also
- [Development and Testing](VirtGPU/development.md)
- [Backend configuration](VirtGPU/configuration.md)

View File

@ -0,0 +1,174 @@
# GGML-VirtGPU Backend Configuration
This document describes the environment variables used by the ggml-virtgpu backend system, covering both the frontend (guest-side) and backend (host-side) components.
## Environment Variables Overview
The ggml-virtgpu backend uses environment variables for configuration across three main components:
- **Frontend (Guest)**: GGML applications running in VMs
- **Hypervisor**: Virglrenderer/APIR system
- **Backend (Host)**: Host-side GGML backend integration
## Frontend (Guest-side) Configuration
### GGML_REMOTING_USE_APIR_CAPSET
- **Location**: `ggml/src/ggml-virtgpu/virtgpu.cpp`
- **Type**: Boolean flag (presence-based)
- **Purpose**: Controls which virtio-gpu capability set to use for communication
- **Values**:
- Set (any value): Use the APIR capset (long-term setup)
- Unset: Use the Venus capset (easier for testing with an unmodified hypervisor)
- **Default**: Unset (Venus capset)
- **Usage**:
```bash
export GGML_REMOTING_USE_APIR_CAPSET=1 # Use APIR capset
# or leave unset for Venus capset
```
## Hypervisor (Virglrenderer/APIR) Configuration
These environment variables are used during the transition phase for
running with an unmodified hypervisor (not supporting the
VirglRenderer APIR component). They will be removed in the future, and
the hypervisor will instead configure VirglRenderer with the APIR
_Configuration Key_.
### VIRGL_APIR_BACKEND_LIBRARY
- **Location**: `virglrenderer/src/apir/apir-context.c`
- **Configuration Key**: `apir.load_library.path`
- **Type**: File path string
- **Purpose**: Path to the APIR backend library that virglrenderer should dynamically load
- **Required**: Yes
- **Example**:
```bash
export VIRGL_APIR_BACKEND_LIBRARY="/path/to/libggml-remotingbackend.so"
```
### VIRGL_ROUTE_VENUS_TO_APIR
- **Location**: `virglrenderer/src/apir/apir-renderer.h`
- **Type**: Boolean flag (presence-based)
- **Purpose**: Temporary workaround to route Venus capset calls to APIR during hypervisor transition period
- **Status**: will be removed once hypervisors support APIR natively
- **Warning**: Breaks normal Vulkan/Venus functionality
- **Usage**:
```bash
export VIRGL_ROUTE_VENUS_TO_APIR=1 # For testing with an unmodified hypervisor
```
### VIRGL_APIR_LOG_TO_FILE
- **Location**: `virglrenderer/src/apir/apir-renderer.c`
- **Environment Variable**: `VIRGL_APIR_LOG_TO_FILE`
- **Type**: File path string
- **Purpose**: Enable debug logging from the VirglRenderer APIR component to specified file
- **Required**: No (optional debugging)
- **Default**: Logging to `stderr`
- **Usage**:
```bash
export VIRGL_APIR_LOG_TO_FILE="/tmp/apir-debug.log"
```
## Backend (Host-side) Configuration
These environment variables are used during the transition phase for
running with an unmodified hypervisor (not supporting the
VirglRenderer APIR component). They will be removed in the future, and
the hypervisor will instead configure VirglRenderer with the APIR
_Configuration Key_.
### APIR_LLAMA_CPP_GGML_LIBRARY_PATH
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp`
- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_PATH`
- **Configuration Key**: `ggml.library.path`
- **Type**: File path string
- **Purpose**: Path to the actual GGML backend library (Metal, CUDA, Vulkan, etc.)
- **Required**: **Yes** - backend initialization fails without this
- **Examples**:
```bash
# macOS with Metal backend
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib"
# Linux with CUDA backend
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-cuda.so"
# macOS or Linux with Vulkan backend
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-vulkan.so"
```
### APIR_LLAMA_CPP_GGML_LIBRARY_REG
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp`
- **Environment Variable**: `APIR_LLAMA_CPP_GGML_LIBRARY_REG`
- **Configuration Key**: `ggml.library.reg`
- **Type**: Function symbol name string
- **Purpose**: Name of the backend registration function to call after loading the library
- **Required**: No (defaults to `ggml_backend_init`)
- **Default**: `ggml_backend_init`
- **Examples**:
```bash
# Metal backend
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg"
# CUDA backend
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_cuda_reg"
# Vulkan backend
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_vulkan_reg"
# Generic fallback (default)
# export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_init"
```
### APIR_LLAMA_CPP_LOG_TO_FILE
- **Location**: `ggml/src/ggml-virtgpu/backend/backend.cpp:62`
- **Environment Variable**: `APIR_LLAMA_CPP_LOG_TO_FILE`
- **Type**: File path string
- **Purpose**: Enable debug logging from the GGML backend to specified file
- **Required**: No (optional debugging)
- **Usage**:
```bash
export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml-backend-debug.log"
```
## Configuration Flow
The configuration system works as follows:
1. **Hypervisor Setup**: Virglrenderer loads the APIR backend library specified by `VIRGL_APIR_BACKEND_LIBRARY`
2. **Context Creation**: When an APIR context is created, it populates a configuration table with environment variables:
- `apir.load_library.path``VIRGL_APIR_BACKEND_LIBRARY`
- `ggml.library.path``APIR_LLAMA_CPP_GGML_LIBRARY_PATH`
- `ggml.library.reg``APIR_LLAMA_CPP_GGML_LIBRARY_REG`
- this step will eventually be performed by the hypervisor itself, with command-line arguments instead of environment variables.
3. **Backend Initialization**: The backend queries the configuration via callbacks:
- `virgl_cbs->get_config(ctx_id, "ggml.library.path")` returns the library path
- `virgl_cbs->get_config(ctx_id, "ggml.library.reg")` returns the registration function
4. **Library Loading**: The backend dynamically loads and initializes the specified GGML library
## Error Messages
Common error scenarios and their messages:
- **Missing library path**: `"cannot open the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_PATH' not defined"`
- **Missing registration function**: `"cannot register the GGML library: env var 'APIR_LLAMA_CPP_GGML_LIBRARY_REG' not defined"`
## Example Complete Configuration
Here's an example configuration for a macOS host with Metal backend:
```bash
# Hypervisor environment
export VIRGL_APIR_BACKEND_LIBRARY="/opt/llama.cpp/lib/libggml-virtgpu-backend.dylib"
# Backend configuration
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="/opt/llama.cpp/lib/libggml-metal.dylib"
export APIR_LLAMA_CPP_GGML_LIBRARY_REG="ggml_backend_metal_reg"
# Optional logging
export VIRGL_APIR_LOG_TO_FILE="/tmp/apir.log"
export APIR_LLAMA_CPP_LOG_TO_FILE="/tmp/ggml.log"
# Guest configuration
export GGML_REMOTING_USE_APIR_CAPSET=1
```

View File

@ -0,0 +1,220 @@
# Development and Testing
## Development
### Code Generation
The backend uses code generation from YAML configuration:
```bash
# Regenerate protocol code
cd ggml-virtgpu/
python regenerate_remoting.py
```
### Adding New Operations
1. Add function definition to `ggmlremoting_functions.yaml`
2. Regenerate code with `regenerate_remoting.py`
3. Implement guest-side forwarding in `virtgpu-forward-*.cpp`
4. Implement host-side handling in `backend-dispatched-*.cpp`
## Testing
This document provides instructions for building and testing the GGML-VirtGPU backend on macOS with containers.
### Prerequisites
The testing setup requires:
- macOS host system
- Container runtime with `libkrun` provider (podman machine)
- Access to development patchset for VirglRenderer
### Required Patchsets
The backend requires patches that are currently under review:
- **Virglrenderer APIR upstream PR**: https://gitlab.freedesktop.org/virgl/virglrenderer/-/merge_requests/1590 (for reference)
- **MacOS Virglrenderer (for krunkit)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-macos
- **Linux Virglrenderer (for krun)**: https://gitlab.freedesktop.org/kpouget/virglrenderer/-/tree/main-linux
### Build Instructions
#### 1. Build ggml-virtgpu-backend (Host-side, macOS)
```bash
# Build the backend that runs natively on macOS
mkdir llama.cpp
cd llama.cpp
git clone https://github.com/ggml-org/llama.cpp.git src
cd src
LLAMA_MAC_BUILD=$PWD/build/ggml-virtgpu-backend
cmake -S . -B $LLAMA_MAC_BUILD \
-DGGML_NATIVE=OFF \
-DLLAMA_CURL=ON \
-DGGML_REMOTINGBACKEND=ONLY \
-DGGML_METAL=ON
TARGETS="ggml-metal"
cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $TARGETS
# Build additional tools for native benchmarking
EXTRA_TARGETS="llama-run llama-bench"
cmake --build $LLAMA_MAC_BUILD --parallel 8 --target $EXTRA_TARGETS
```
#### 2. Build virglrenderer (Host-side, macOS)
```bash
# Build virglrenderer with APIR support
mkdir virglrenderer
git clone https://gitlab.freedesktop.org/kpouget/virglrenderer -b main-macos src
cd src
VIRGL_BUILD_DIR=$PWD/build
# -Dvenus=true and VIRGL_ROUTE_VENUS_TO_APIR=1 route the APIR requests via the Venus backend, for easier testing without a patched hypervisor
meson setup $VIRGL_BUILD_DIR \
-Dvenus=true \
-Dapir=true
ninja -C $VIRGL_BUILD_DIR
```
#### 3. Build ggml-virtgpu (Guest-side, Linux)
Option A: Build from a script:
```bash
# Inside a Linux container
mkdir llama.cpp
git clone https://github.com/ggml-org/llama.cpp.git src
cd src
LLAMA_LINUX_BUILD=$PWD//build-virtgpu
cmake -S . -B $LLAMA_LINUX_BUILD \
-DGGML_VIRTGPU=ON
ninja -C $LLAMA_LINUX_BUILD
```
Option B: Build container image with frontend:
```bash
cat << EOF > remoting.containerfile
FROM quay.io/fedora/fedora:43
USER 0
WORKDIR /app/remoting
ARG LLAMA_CPP_REPO="https://github.com/ggml-org/llama.cpp.git"
ARG LLAMA_CPP_VERSION="master"
ARG LLAMA_CPP_CMAKE_FLAGS="-DGGML_VIRTGPU=ON"
ARG LLAMA_CPP_CMAKE_BUILD_FLAGS="--parallel 4"
RUN dnf install -y git cmake gcc gcc-c++ libcurl-devel libdrm-devel
RUN git clone "\${LLAMA_CPP_REPO}" src \\
&& git -C src fetch origin \${LLAMA_CPP_VERSION} \\
&& git -C src reset --hard FETCH_HEAD
RUN mkdir -p build \\
&& cd src \\
&& set -o pipefail \\
&& cmake -S . -B ../build \${LLAMA_CPP_CMAKE_FLAGS} \\
&& cmake --build ../build/ \${LLAMA_CPP_CMAKE_BUILD_FLAGS}
ENTRYPOINT ["/app/remoting/src/build/bin/llama-server"]
EOF
mkdir -p empty_dir
podman build -f remoting.containerfile ./empty_dir -t localhost/llama-cpp.virtgpu
```
### Environment Setup
#### Set krunkit Environment Variables
```bash
# Define the base directories (adapt these paths to your system)
VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build
LLAMA_MAC_BUILD=$HOME/remoting/llama.cpp/build-backend
# For krunkit to load the custom virglrenderer library
export DYLD_LIBRARY_PATH=$VIRGL_BUILD_DIR/src
# For Virglrenderer to load the ggml-remotingbackend library
export VIRGL_APIR_BACKEND_LIBRARY="$LLAMA_MAC_BUILD/bin/libggml-virtgpu-backend.dylib"
# For llama.cpp remotingbackend to load the ggml-metal backend
export APIR_LLAMA_CPP_GGML_LIBRARY_PATH="$LLAMA_MAC_BUILD/bin/libggml-metal.dylib"
export APIR_LLAMA_CPP_GGML_LIBRARY_REG=ggml_backend_metal_reg
```
#### Launch Container Environment
```bash
# Set container provider to libkrun
export CONTAINERS_MACHINE_PROVIDER=libkrun
podman machine start
```
#### Verify Environment
Confirm that krunkit is using the correct virglrenderer library:
```bash
lsof -c krunkit | grep virglrenderer
# Expected output:
# krunkit 50574 user txt REG 1,14 2273912 10849442 ($VIRGL_BUILD_DIR/src)/libvirglrenderer.1.dylib
```
### Running Tests
#### Launch Test Container
```bash
# Optional model caching
mkdir -p models
PODMAN_CACHE_ARGS="-v models:/models --user root:root --cgroupns host --security-opt label=disable -w /models"
podman run $PODMAN_CACHE_ARGS -it --rm --device /dev/dri localhost/llama-cpp.virtgpu
```
#### Test llama.cpp in Container
```bash
# Run performance benchmark
/app/remoting/build/bin/llama-bench -m ./llama3.2
```
Expected output (performance may vary):
```
| model | size | params | backend | ngl | test | t/s |
| ------------------------------ | ---------: | ---------: | ---------- | --: | ------------: | -------------------: |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | pp512 | 991.30 ± 0.66 |
| llama 3B Q4_K - Medium | 1.87 GiB | 3.21 B | ggml-virtgpu | 99 | tg128 | 85.71 ± 0.11 |
```
### Troubleshooting
#### SSH Environment Variable Issues
⚠️ **Warning**: Setting `DYLD_LIBRARY_PATH` from SSH doesn't work on macOS. Here is a workaround:
**Workaround 1: Replace system library**
```bash
VIRGL_BUILD_DIR=$HOME/remoting/virglrenderer/build # ⚠️ adapt to your system
BREW_VIRGL_DIR=/opt/homebrew/Cellar/virglrenderer/0.10.4d/lib
VIRGL_LIB=libvirglrenderer.1.dylib
cd $BREW_VIRGL_DIR
mv $VIRGL_LIB ${VIRGL_LIB}.orig
ln -s $VIRGL_BUILD_DIR/src/$VIRGL_LIB
```

View File

@ -35,7 +35,7 @@ Adapt below build commands accordingly.
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
```
[d]/workspace> cp docs/backend/hexagon/CMakeUserPresets.json .
[d]/workspace> cp docs/backend/snapdragon/CMakeUserPresets.json .
[d]/workspace> cmake --preset arm64-android-snapdragon-release -B build-snapdragon
Preset CMake variables:

View File

@ -22,7 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |

View File

@ -77,8 +77,8 @@
"SYCL0","GELU_ERF","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","FLOOR","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL"
"SYCL0","CEIL","type=f16,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ROUND","type=f16,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","TRUNC","type=f16,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
@ -161,8 +161,8 @@
"SYCL0","GELU_ERF","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","FLOOR","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[128,2,2,2],v=1","support","1","yes","SYCL"
"SYCL0","CEIL","type=f32,ne_a=[5,7,11,13],v=1","support","1","yes","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"
"SYCL0","ROUND","type=f32,ne_a=[5,7,11,13],v=1","support","0","no","SYCL"
"SYCL0","TRUNC","type=f32,ne_a=[128,2,2,2],v=1","support","0","no","SYCL"

Can't render this file because it is too large.

View File

@ -119,8 +119,6 @@ If a draft model is combined with a draftless decoding the draftless decoding ha
of lookup n-gram (default: 12)
--spec-ngram-size-m N ngram size M for ngram-simple/ngram-map speculative decoding, length
of draft m-gram (default: 48)
--spec-ngram-check-rate N ngram check rate for ngram-simple/ngram-map speculative decoding
(default: 1)
--spec-ngram-min-hits N minimum hits for ngram-map speculative decoding (default: 1)
```
@ -153,10 +151,6 @@ Sets the size M of the draft m-gram for n-gram map based speculative decoding.
The m-gram size determines how many tokens to draft when a match is found.
Larger values can provide more speedup but may reduce acceptance rate.
### `--spec-ngram-check-rate R`
This option aims at performance if the n-gram lookup in history is to costly. A lookup will be executed at every R tokens (default is 1, every token).
### `--spec-ngram-min-hits H`
This option defines how often a key has to appear in the token history to be used as a draft (default is 1).
@ -175,7 +169,12 @@ draft acceptance rate = 0.70312 ( 90 accepted / 128 generated)
statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
```
- `#calls`: number of calls of this implementations
```
statistics ngram_map_k: #calls(b,g,a) = 6 1690 26, #gen drafts = 26, #acc drafts = 26, #gen tokens = 1248, #acc tokens = 968, dur(b,g,a) = 2.234, 1.427, 0.016 ms
```
- `#calls(b,g,a)`: number of calls of begin (new prompt), generation and accumulation of this implementations
- `#gen drafts`: number of drafts generated by this implementation
- `#acc drafts`: number of drafts accepted (partially) by the main model
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)

View File

@ -471,9 +471,10 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
int best_score = 0;
fs::path best_path;
std::error_code ec;
for (const auto & search_path : search_paths) {
if (std::error_code ec; !fs::exists(search_path, ec)) {
if (!fs::exists(search_path, ec)) {
if (ec) {
GGML_LOG_DEBUG("%s: posix_stat(%s) failure, error-message: %s\n", __func__, path_str(search_path).c_str(), ec.message().c_str());
} else {
@ -483,7 +484,7 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent,
}
fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied);
for (const auto & entry : dir_it) {
if (entry.is_regular_file()) {
if (entry.is_regular_file(ec)) {
auto filename = entry.path().filename();
auto ext = entry.path().extension();
if (filename.native().find(file_prefix) == 0 && ext == file_extension) {

View File

@ -3286,130 +3286,223 @@ static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context & ctx, ggml_tensor
}
/**
* @brief Performs expert-specific matrix multiplication (MoE) with
* quantized precision using the CANN backend.
* @brief Performs quantized matrix multiplication for Mixture of Experts (MoE)
* models using the CANN backend.
*
* This function executes a matrix multiplication operation tailored for
* Mixture of Experts (MoE) models, where the input tensor is multiplied
* with expert-specific quantized weight matrices. It leverages the CANN
* backend to perform efficient low-precision computations and stores the
* quantized result in the destination tensor `dst`.
* This function implements MUL_MAT_ID operation for quantized weight matrices
* (Q4_0 and Q8_0 formats). It selects expert-specific weight matrices based on
* the provided expert indices, and computes matrix multiplication using CANN's
* WeightQuantBatchMatmulV2 operator.
*
* Quantization techniques reduce memory footprint and improve performance
* by using lower-bit representations (e.g., int8) instead of floating-point.
* This function is designed to work with such formats and may incorporate
* optimizations like identity-based fast paths or routing masks for sparse
* expert selection.
* The function performs the following steps:
* 1. Converts input/output tensors to F16 format if necessary
* 2. Uses IndexSelect to extract expert-specific weights and scales based on indices
* 3. Performs quantized matrix multiplication for each expert using WeightQuantBatchMatmulV2
* 4. Converts output back to the target type if needed
*
* @param ctx The context for executing CANN backend operations.
* @param dst The destination tensor where the quantized MoE multiplication result
* will be stored.
* Tensor shapes:
* - dst: [M, K, N, 1] - output tensor
* - src0: [D, M, A, 1] - quantized weight matrices (Q4_0 or Q8_0)
* - src1: [D, B, N, 1] - input activations (B = K for per-expert input, or B = 1 for broadcast)
* - ids: [K, N] - expert indices for routing
*
* @note This function assumes quantized data types and is designed for
* MoE architectures with potential sparse expert routing.
* @param ctx The CANN backend context for operation execution.
* @param dst The destination tensor where the multiplication result will be stored.
*
* @note Only Q4_0 and Q8_0 quantization formats are supported.
* @note The function handles automatic type conversion to/from F16 as needed by the hardware.
*/
static void ggml_cann_mul_mat_id_quant(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
// TODO: Use aclnnGroupedMatMul
//dst [M, K, N, 1]
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1]
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1
ggml_tensor * ids = dst->src[2]; //ids [K, N]
// dst: [M, K, N, 1]
// src0: [D, M, A, 1] - quantized weights
// src1: [D, B, N, 1] - input activations, B = K or B = 1
// ids: [K, N] - expert indices
ggml_tensor * src0 = dst->src[0];
ggml_tensor * src1 = dst->src[1];
ggml_tensor * ids = dst->src[2];
GGML_TENSOR_BINARY_OP_LOCALS
GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(src1->ne[3] == 1);
GGML_ASSERT(dst->ne[3] == 1);
GGML_ASSERT(src1->ne[2] == ids->ne[1]);
// copy index from npu to cpu
int64_t n_as = ne02; // A
int64_t n_ids = ids->ne[0]; // K
const int64_t n_batches = ids->ne[1];
const int64_t n_select_experts = ids->ne[0];
const enum ggml_type type = src0->type;
std::vector<char> ids_host(ggml_nbytes(ids));
ACL_CHECK(aclrtMemcpyAsync(ids_host.data(), ggml_nbytes(ids), ids->data, ggml_nbytes(ids),
ACL_MEMCPY_DEVICE_TO_HOST, ctx.stream()));
ACL_CHECK(aclrtSynchronizeStream(ctx.stream()));
const int32_t group_size = QK8_0; // Both Q4_0 and Q8_0 use group size of 32
GGML_ASSERT(group_size == QK4_0);
char * src0_original = (char *) src0->data;
char * src1_original = (char *) src1->data;
char * dst_original = (char *) dst->data;
// Calculate element size for quantized weights
const float weight_elem_size =
(type == GGML_TYPE_Q4_0) ? 0.5f :
(type == GGML_TYPE_Q8_0) ? 1.0f :
(GGML_ABORT("MUL_MAT_ID only supports Q4_0 and Q8_0"), 0.0f);
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
// Calculate scale offset in memory
const size_t weight_size = src0->ne[0] * src0->ne[1] * src0->ne[2] * weight_elem_size;
const size_t scale_elem_size = sizeof(uint16_t);
char * scale_data = (char *) src0->data + weight_size;
const enum ggml_type type = dst->src[0]->type;
float weight_elem_size;
if (type == GGML_TYPE_Q4_0) {
weight_elem_size = float(sizeof(uint8_t)) / 2;
} else if (type == GGML_TYPE_Q8_0) {
weight_elem_size = float(sizeof(uint8_t));
} else {
GGML_ABORT("MUL_MAT_ID only support quant type Q4_0 and Q8_0 ");
}
// Allocate buffers for selected expert weights and scales
const size_t selected_weight_size = src0->ne[0] * src0->ne[1] * n_select_experts * weight_elem_size;
ggml_cann_pool_alloc selected_weight_alloc(ctx.pool(), selected_weight_size);
void * selected_weight_buffer = selected_weight_alloc.get();
// src0_row [D, M, 1, 1] weight without permute
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[0] = weight_elem_size;
src0_row.nb[1] = weight_elem_size * ne00;
src0_row.nb[2] = weight_elem_size * ne00;
src0_row.nb[3] = weight_elem_size * ne00;
size_t weight_stride = ne00 * ne01 * weight_elem_size;
size_t weight_size = weight_stride * ne02 * ne03;
const size_t selected_scale_size = (src0->ne[0] / group_size) * src0->ne[1] * n_select_experts * scale_elem_size;
ggml_cann_pool_alloc selected_scale_alloc(ctx.pool(), selected_scale_size);
void * selected_scale_buffer = selected_scale_alloc.get();
// scale [D, M, 1, 1] -> scale && permute
size_t scale_elem_size = sizeof(uint16_t);
size_t scale_stride = src0->ne[1] * src0->ne[0] / QK8_0 * scale_elem_size;
// Helper lambda to allocate and cast tensor to F16 if needed
constexpr size_t f16_elem_size = sizeof(uint16_t);
auto prepare_f16_buffer = [&](ggml_tensor * tensor, ggml_cann_pool_alloc & allocator,
bool need_cast = false) -> void * {
if (tensor->type == GGML_TYPE_F16) {
return tensor->data;
}
// src1_row [D, 1, 1, 1] -> input
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
size_t total_size = f16_elem_size;
for (int i = 0; i < GGML_MAX_DIMS; i++) {
total_size *= tensor->ne[i];
}
void * buffer = allocator.alloc(total_size);
// dst_row [M, 1, 1, 1] -> out
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
if (need_cast == false) {
return buffer;
}
//create weight for one row
ggml_cann_pool_alloc weight_allocator(ctx.pool());
void * weight_buffer = weight_allocator.alloc(nb02);
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
// expert index
int32_t i02 = *(int32_t *) (ids_host.data() + iid1 * ids->nb[1] + id * ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
for (int i = 0; i < GGML_MAX_DIMS; i++) {
ne[i] = tensor->ne[i];
if (i > 0) {
nb[i] = nb[i - 1] * ne[i - 1];
}
}
// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;
acl_tensor_ptr src_tensor = ggml_cann_create_tensor(tensor);
acl_tensor_ptr f16_tensor = ggml_cann_create_tensor(buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
aclnn_cast(ctx, src_tensor.get(), f16_tensor.get(), ACL_FLOAT16);
int64_t i1 = id;
int64_t i2 = i12;
return buffer;
};
void * src0_tmp_ptr = src0_original + i02 * weight_stride;
void * scale_tmp_ptr = src0_original + weight_size + i02 * scale_stride;
void * src1_tmp_ptr = src1_original + i11 * nb11 + i12 * nb12;
void * dst_tmp_ptr = dst_original + i1 * nb1 + i2 * nb2;
// Prepare input and output buffers
ggml_cann_pool_alloc input_alloc(ctx.pool());
void * input_buffer = prepare_f16_buffer(src1, input_alloc, true);
// mem cpy
ACL_CHECK(aclrtMemcpyAsync(weight_buffer, weight_stride, src0_tmp_ptr, weight_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
void * scale_buffer = (char *) weight_buffer + weight_stride;
ACL_CHECK(aclrtMemcpyAsync(scale_buffer, scale_stride, scale_tmp_ptr, scale_stride,
ACL_MEMCPY_DEVICE_TO_DEVICE, ctx.stream()));
ggml_cann_pool_alloc output_alloc(ctx.pool());
void * output_buffer = prepare_f16_buffer(dst, output_alloc, false);
src0_row.data = weight_buffer;
src1_row.data = src1_tmp_ptr;
dst_row.data = dst_tmp_ptr;
dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;
// Process each batch
for (int64_t batch_idx = 0; batch_idx < n_batches; batch_idx++) {
// Create index tensor for current batch
const size_t index_offset = batch_idx * ids->nb[1];
acl_tensor_ptr batch_indices = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, index_offset);
ggml_cann_mul_mat(ctx, &dst_row);
// Select quantized weights using expert indices
// Q4_0 stores 2 values per byte, Q8_0 stores 1 value per byte
const int64_t weight_d = (type == GGML_TYPE_Q4_0) ? src0->ne[0] / 2 : src0->ne[0];
const int64_t weight_m = src0->ne[1];
const int64_t weight_n_experts = src0->ne[2];
int64_t weight_ne[3] = { weight_d, weight_m, weight_n_experts };
size_t weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t), weight_d * weight_m * sizeof(int8_t) };
acl_tensor_ptr all_weights =
ggml_cann_create_tensor(src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, 3);
int64_t selected_weight_ne[3] = { weight_d, weight_m, n_select_experts };
size_t selected_weight_nb[3] = { sizeof(int8_t), weight_d * sizeof(int8_t),
weight_d * weight_m * sizeof(int8_t) };
acl_tensor_ptr selected_weights = ggml_cann_create_tensor(selected_weight_buffer, ACL_INT8, sizeof(int8_t),
selected_weight_ne, selected_weight_nb, 3);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_weights.get(), 0, batch_indices.get(), selected_weights.get());
// Select scales using the same expert indices
const int64_t scale_d = src0->ne[0] / group_size;
int64_t scale_ne[3] = { scale_d, weight_m, weight_n_experts };
size_t scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size, scale_d * weight_m * scale_elem_size };
acl_tensor_ptr all_scales =
ggml_cann_create_tensor(scale_data, ACL_FLOAT16, scale_elem_size, scale_ne, scale_nb, 3);
int64_t selected_scale_ne[3] = { scale_d, weight_m, n_select_experts };
size_t selected_scale_nb[3] = { scale_elem_size, scale_d * scale_elem_size,
scale_d * weight_m * scale_elem_size };
acl_tensor_ptr selected_scales = ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size,
selected_scale_ne, selected_scale_nb, 3);
GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, all_scales.get(), 0, batch_indices.get(), selected_scales.get());
// Process each expert for current batch
// IndexSelect output layout: [D, M, K] in contiguous format
// WeightQuantBatchMatmulV2 expects: [M, D] with row-major stride
for (int64_t expert_idx = 0; expert_idx < n_select_experts; expert_idx++) {
// Determine input offset: broadcast if src1->ne[1]==1, otherwise use per-expert input
const size_t input_offset =
(batch_idx * src1->ne[1] + (src1->ne[1] == 1 ? 0 : expert_idx)) * src1->ne[0] * f16_elem_size;
const size_t output_offset = (batch_idx * dst->ne[1] + expert_idx) * dst->ne[0] * f16_elem_size;
// Create weight view for current expert: [D, M, K] -> [M, D]
int64_t weight_view_ne[2] = { weight_m, src0->ne[0] };
float weight_view_nb[2] = { src0->ne[0] * weight_elem_size, weight_elem_size };
const size_t weight_view_offset = expert_idx * selected_weight_nb[2];
acl_tensor_ptr weight_view =
ggml_cann_create_tensor(selected_weight_buffer, ggml_cann_type_mapping(type), weight_elem_size,
weight_view_ne, weight_view_nb, 2, ACL_FORMAT_ND, weight_view_offset);
// Create scale view for current expert: [D, M, K] -> [M, D]
int64_t scale_view_ne[2] = { weight_m, scale_d };
size_t scale_view_nb[2] = { selected_scale_nb[1], selected_scale_nb[0] };
const size_t scale_view_offset = expert_idx * selected_scale_nb[2];
acl_tensor_ptr scale_view =
ggml_cann_create_tensor(selected_scale_buffer, ACL_FLOAT16, scale_elem_size, scale_view_ne,
scale_view_nb, 2, ACL_FORMAT_ND, scale_view_offset);
// Create input activation tensor [D, 1]
int64_t input_ne[2] = { src1->ne[0], 1 };
size_t input_nb[2] = { f16_elem_size, src1->ne[0] * f16_elem_size };
acl_tensor_ptr input_tensor = ggml_cann_create_tensor(input_buffer, ACL_FLOAT16, f16_elem_size, input_ne,
input_nb, 2, ACL_FORMAT_ND, input_offset);
// Create output tensor [M, 1]
int64_t output_ne[2] = { dst->ne[0], 1 };
size_t output_nb[2] = { f16_elem_size, dst->ne[0] * f16_elem_size };
acl_tensor_ptr output_tensor = ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, output_ne,
output_nb, 2, ACL_FORMAT_ND, output_offset);
// Perform quantized matrix multiplication
GGML_CANN_CALL_ACLNN_OP(ctx, WeightQuantBatchMatmulV2, input_tensor.get(), weight_view.get(),
scale_view.get(), nullptr, nullptr, nullptr, nullptr, group_size,
output_tensor.get());
}
}
return;
// Cast output back to original type if we used a temporary F16 buffer
if (dst->type != GGML_TYPE_F16) {
int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS] = { f16_elem_size };
for (int i = 0; i < GGML_MAX_DIMS; i++) {
ne[i] = dst->ne[i];
if (i > 0) {
nb[i] = nb[i - 1] * ne[i - 1];
}
}
acl_tensor_ptr f16_output =
ggml_cann_create_tensor(output_buffer, ACL_FLOAT16, f16_elem_size, ne, nb, GGML_MAX_DIMS);
acl_tensor_ptr dst_tensor = ggml_cann_create_tensor(dst);
aclnn_cast(ctx, f16_output.get(), dst_tensor.get(), ggml_cann_type_mapping(dst->type));
}
}
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst) {

View File

@ -794,19 +794,44 @@ struct ggml_backend_cann_buffer_context {
~ggml_backend_cann_buffer_context() { ACL_CHECK(aclrtFree(dev_ptr)); }
};
// cann buffer type
/**
* @brief Check if a buffer is a CANN buffer.
*
* This function checks if a given buffer is a CANN buffer by comparing its
* `get_name` function pointer to `ggml_backend_cann_buffer_get_name`.
*
* @param buffer The buffer to check.
* @return true if the buffer is a CANN buffer, false otherwise.
* @brief Structure representing context information for a specific backend
* buffer type.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft);
struct ggml_backend_cann_buffer_type_context {
int32_t device; /**< Device identifier associated with the buffer context. */
std::string name; /**< Name associated with the buffer context. */
};
static bool ggml_backend_buffer_is_cann(ggml_backend_buffer_t buffer) {
return ggml_backend_buft_is_cann(buffer->buft);
/**
* @brief Retrieves the name associated with a CANN buffer type.
*
* This function returns the descriptive name associated with the specified
* CANN buffer type context.
*
* @param buft Pointer to the buffer type context.
* @return Const pointer to the C-style string containing the name.
*/
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
return buft_ctx->name.c_str();
}
/**
* @brief Checks if the backend buffer type is associated with the CANN backend.
*
* This function checks whether the provided backend buffer type is associated
* with the CANN backend based on the comparison of its name retrieval function
* pointer.
*
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the buffer type is associated with the CANN
* backend, otherwise false.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
@ -1271,7 +1296,7 @@ static void ggml_backend_cann_buffer_get_tensor(ggml_backend_buffer_t buffer,
static bool ggml_backend_cann_buffer_cpy_tensor(ggml_backend_buffer_t buffer,
const ggml_tensor * src,
ggml_tensor * dst) {
if (ggml_backend_buffer_is_cann(src->buffer)) {
if (ggml_backend_buft_is_cann(src->buffer->buft)) {
ggml_backend_cann_buffer_context * src_ctx = (ggml_backend_cann_buffer_context *) src->buffer->context;
ggml_backend_cann_buffer_context * dst_ctx = (ggml_backend_cann_buffer_context *) buffer->context;
@ -1335,31 +1360,6 @@ static const ggml_backend_buffer_i ggml_backend_cann_buffer_interface = {
/* .reset = */ NULL,
};
// cann buffer type
/**
* @brief Structure representing context information for a specific backend
* buffer type.
*/
struct ggml_backend_cann_buffer_type_context {
int32_t device; /**< Device identifier associated with the buffer context. */
std::string name; /**< Name associated with the buffer context. */
};
/**
* @brief Retrieves the name associated with a CANN buffer type.
*
* This function returns the descriptive name associated with the specified
* CANN buffer type context.
*
* @param buft Pointer to the buffer type context.
* @return Const pointer to the C-style string containing the name.
*/
static const char * ggml_backend_cann_buffer_type_name(ggml_backend_buffer_type_t buft) {
ggml_backend_cann_buffer_type_context * buft_ctx = (ggml_backend_cann_buffer_type_context *) buft->context;
return buft_ctx->name.c_str();
}
/**
* @brief Allocates a new CANN buffer of the specified type and size.
*
@ -1997,7 +1997,7 @@ static bool ggml_backend_cann_cpy_tensor_async(ggml_backend_t backend_src,
GGML_ASSERT(!is_matmul_weight((const ggml_tensor *) src));
if (!ggml_backend_buffer_is_cann(src->buffer) || !ggml_backend_buffer_is_cann(dst->buffer)) {
if (!ggml_backend_buft_is_cann(src->buffer->buft) || !ggml_backend_buft_is_cann(dst->buffer->buft)) {
return false;
}
@ -2523,21 +2523,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev, const ggml_ten
GGML_UNUSED(dev);
}
/**
* @brief Checks if the backend buffer type is associated with the CANN backend.
*
* This function checks whether the provided backend buffer type is associated
* with the CANN backend based on the comparison of its name retrieval function
* pointer.
*
* @param buft Pointer to the backend buffer type to check.
* @return bool Returns true if the buffer type is associated with the CANN
* backend, otherwise false.
*/
static bool ggml_backend_buft_is_cann(ggml_backend_buffer_type_t buft) {
return buft->iface.get_name == ggml_backend_cann_buffer_type_name;
}
/**
* @brief Records an event on the CANN backend stream.
*

View File

@ -43,6 +43,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -55,7 +56,8 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
# define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -76,6 +78,7 @@
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_q8_0_4x4_q8_0_generic ggml_gemv_q8_0_4x4_q8_0
@ -84,6 +87,7 @@
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_q8_0_4x4_q8_0_generic ggml_gemm_q8_0_4x4_q8_0
@ -107,6 +111,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -119,6 +124,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@ -143,6 +149,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -155,6 +162,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@ -186,6 +194,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -197,6 +206,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@ -227,6 +237,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -239,6 +250,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
@ -271,6 +283,7 @@
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q5_K_8x8_q8_K_generic ggml_gemv_q5_K_8x8_q8_K
#define ggml_gemv_q6_K_8x4_q8_K_generic ggml_gemv_q6_K_8x4_q8_K
#define ggml_gemv_q6_K_8x8_q8_K_generic ggml_gemv_q6_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
@ -283,6 +296,7 @@
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q5_K_8x8_q8_K_generic ggml_gemm_q5_K_8x8_q8_K
#define ggml_gemm_q6_K_8x4_q8_K_generic ggml_gemm_q6_K_8x4_q8_K
#define ggml_gemm_q6_K_8x8_q8_K_generic ggml_gemm_q6_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0

View File

@ -1072,6 +1072,195 @@ void ggml_gemv_q5_K_8x8_q8_K(int n,
ggml_gemv_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[2];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q6_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q6_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0 = vmulq_f32(q6_d_0, q8_d);
float32x4_t sb_scale_1 = vmulq_f32(q6_d_1, q8_d);
int32x4_t acc[col_groups];
for (int i = 0; i < col_groups; i++) {
acc[i] = vdupq_n_s32(0);
}
// Load all 16 scales once and widen to int16 (Q6_K has 16 scales per block)
// Reused for bias and dequantization later
int16_t q6_scales[16 * 8];
for (int i = 0; i < 16; i++) {
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, scales);
}
// Compute bias per column using q8 bsums and preloaded scales to skip the -32 shift
int32x4_t bias_lo = vdupq_n_s32(0);
int32x4_t bias_hi = vdupq_n_s32(0);
// Load bsums in chunks of 4 to process with vectorized operations
for (int i = 0; i < 16; i += 4) {
int16x4_t bsums_vec = vld1_s16(q8_ptr[b].bsums + i);
int16x4_t scales_lo_0 = vld1_s16(q6_scales + (i + 0) * 8);
int16x4_t scales_hi_0 = vld1_s16(q6_scales + (i + 0) * 8 + 4);
int16x4_t scales_lo_1 = vld1_s16(q6_scales + (i + 1) * 8);
int16x4_t scales_hi_1 = vld1_s16(q6_scales + (i + 1) * 8 + 4);
int16x4_t scales_lo_2 = vld1_s16(q6_scales + (i + 2) * 8);
int16x4_t scales_hi_2 = vld1_s16(q6_scales + (i + 2) * 8 + 4);
int16x4_t scales_lo_3 = vld1_s16(q6_scales + (i + 3) * 8);
int16x4_t scales_hi_3 = vld1_s16(q6_scales + (i + 3) * 8 + 4);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_0, bsums_vec, 0);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_0, bsums_vec, 0);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_1, bsums_vec, 1);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_1, bsums_vec, 1);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_2, bsums_vec, 2);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_2, bsums_vec, 2);
bias_lo = vmlal_lane_s16(bias_lo, scales_lo_3, bsums_vec, 3);
bias_hi = vmlal_lane_s16(bias_hi, scales_hi_3, bsums_vec, 3);
}
bias_lo = vshlq_n_s32(bias_lo, 5);
bias_hi = vshlq_n_s32(bias_hi, 5);
// Process two 128-value halves per superblock
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
// A subblock (sb) is a set of weights that share the scale
// Since q6_K scales are per 16 elements
// num sbs -> 256 elements / (16 elements/scale * 2 elements/byte * 2 halves)
for (int sb = 0; sb < QK_K / 64; sb++) {
const int8_t * q8_base_l = q8_ptr[b].qs + half * 128 + sb * 16;
const int8_t * q8_base_h = q8_base_l + 64;
// Load and duplicate q8 values (each register covers four interleaved columns of q6)
int8x16_t q8_l[4];
int8x16_t q8_h[4];
for (int i = 0; i < 4; i++) {
q8_l[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_l + i * 4));
q8_h[i] = (int8x16_t) vld1q_dup_s32((const int32_t *) (q8_base_h + i * 4));
}
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
// Adjust qh for subblocks 2 and 3 (shift right by 2)
if (sb > 1) {
q6_qh_0.val[0] = vshrq_n_u8(q6_qh_0.val[0], 2);
q6_qh_0.val[1] = vshrq_n_u8(q6_qh_0.val[1], 2);
q6_qh_0.val[2] = vshrq_n_u8(q6_qh_0.val[2], 2);
q6_qh_0.val[3] = vshrq_n_u8(q6_qh_0.val[3], 2);
q6_qh_1.val[0] = vshrq_n_u8(q6_qh_1.val[0], 2);
q6_qh_1.val[1] = vshrq_n_u8(q6_qh_1.val[1], 2);
q6_qh_1.val[2] = vshrq_n_u8(q6_qh_1.val[2], 2);
q6_qh_1.val[3] = vshrq_n_u8(q6_qh_1.val[3], 2);
}
const uint8x16_t q6_ql[8] = { q6_ql_0.val[0], q6_ql_0.val[1], q6_ql_0.val[2], q6_ql_0.val[3],
q6_ql_1.val[0], q6_ql_1.val[1], q6_ql_1.val[2], q6_ql_1.val[3] };
const uint8x16_t q6_qh[8] = { q6_qh_0.val[0], q6_qh_0.val[1], q6_qh_0.val[2], q6_qh_0.val[3],
q6_qh_1.val[0], q6_qh_1.val[1], q6_qh_1.val[2], q6_qh_1.val[3] };
// Process column groups (0-3, 4-7)
for (int g = 0; g < col_groups; g++) {
int32x4_t sb_acc_l = vdupq_n_s32(0);
int32x4_t sb_acc_h = vdupq_n_s32(0);
for (int chunk = 0; chunk < 4; chunk++) {
const int idx = chunk * 2 + g;
const uint8x16_t q6_qs_l = q6_ql[idx];
const uint8x16_t q6_qs_h = q6_qh[idx];
// Extract high 2 bits for upper nibble reconstruction
const uint8x16_t q6_qs_hh = vandq_u8(q6_qs_h, mask_hi);
// q6 = (low4 | high2<<4), without -32 bias (handled via bsums)
const int8x16_t q6_l =
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_qs_l, m4b), vandq_u8(q6_qs_h, mask_lo), 4));
const int8x16_t q6_h = vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_qs_l, 4), q6_qs_hh));
sb_acc_l = vdotq_s32(sb_acc_l, q6_l, q8_l[chunk]);
sb_acc_h = vdotq_s32(sb_acc_h, q6_h, q8_h[chunk]);
}
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
const int32x4_t scale_vec_l = vmovl_s16(vld1_s16(q6_scales + scale_idx_l * 8 + g * 4));
const int32x4_t scale_vec_h = vmovl_s16(vld1_s16(q6_scales + scale_idx_h * 8 + g * 4));
acc[g] = vmlaq_s32(acc[g], sb_acc_l, scale_vec_l);
acc[g] = vmlaq_s32(acc[g], sb_acc_h, scale_vec_h);
}
}
} // for half
// Bias correction
acc[0] = vsubq_s32(acc[0], bias_lo);
acc[1] = vsubq_s32(acc[1], bias_hi);
// Apply superblock scale (no mins for q6_K)
// acc[g] has [c0, c1, c2, c3]
float32x4_t w_0123 = vmulq_f32(vcvtq_f32_s32(acc[0]), sb_scale_0);
float32x4_t w_4567 = vmulq_f32(vcvtq_f32_s32(acc[1]), sb_scale_1);
acc_f32[0] = vaddq_f32(acc_f32[0], w_0123);
acc_f32[1] = vaddq_f32(acc_f32[1], w_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
@ -1177,15 +1366,14 @@ void ggml_gemv_q6_K_8x8_q8_K(int n,
q8_h[i] = (int8x16_t) vld1q_dup_s64((const int64_t *) (q8_base_h + i * 8));
}
// TODO: Test other qh repack patterns to reduce loads
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255; // wraps after 256 bytes
// Load 4 vectors at once (64 bytes each for ql_0, ql_1, qh_0, qh_1)
ggml_uint8x16x4_t q6_ql_0 = ggml_vld1q_u8_x4(ql_base + ql_off_base);
ggml_uint8x16x4_t q6_ql_1 = ggml_vld1q_u8_x4(ql_base + ql_off_base + 64);
ggml_uint8x16x4_t q6_qh_0 = ggml_vld1q_u8_x4(qh_base + qh_off_base);
ggml_uint8x16x4_t q6_qh_1 = ggml_vld1q_u8_x4(qh_base + qh_off_base + 64);
uint8x16x4_t q6_ql_0 = vld1q_u8_x4(ql_base + ql_off_base);
uint8x16x4_t q6_ql_1 = vld1q_u8_x4(ql_base + ql_off_base + 64);
uint8x16x4_t q6_qh_0 = vld1q_u8_x4(qh_base + qh_off_base);
uint8x16x4_t q6_qh_1 = vld1q_u8_x4(qh_base + qh_off_base + 64);
// Adjust qh for subblocks 2 and 3 (shift right by 2)
if (sb > 1) {
@ -3474,6 +3662,208 @@ void ggml_gemm_q5_K_8x8_q8_K(int n,
ggml_gemm_q5_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x4_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int col_groups = ncols_interleaved / 4;
constexpr int acc_size = q8_k_blocklen * col_groups; // 4 rows, 2 column groups
const uint8x16_t m4b = vdupq_n_u8(0x0f);
const uint8x16_t mask_lo = vdupq_n_u8(0x03);
const uint8x16_t mask_hi = vdupq_n_u8(0x30);
const int8x16_t m32s = vdupq_n_s8(32);
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * GGML_RESTRICT q6_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q6_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d));
float32x4_t q6_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q6_ptr[b].d + 4));
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q6_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q6_d_4567, q8_d_0123, 3);
int32x4_t acc_s32[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_s32[i] = vdupq_n_s32(0);
}
int16_t q6_scales[8 * 16];
for (int i = 0; i < 16; i++) {
int16x8_t scales = vmovl_s8(vld1_s8(q6_ptr[b].scales + i * 8));
vst1q_s16(q6_scales + i * 8, scales);
}
for (int half = 0; half < 2; half++) {
const uint8_t * ql_base = q6_ptr[b].ql + half * 512;
const uint8_t * qh_base = q6_ptr[b].qh + half * 256;
for (int sb = 0; sb < QK_K / 64; sb++) {
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
const int8_t * q8_base_l = q8_ptr[b].qs + half * 512 + sb * 64;
const int8_t * q8_base_h = q8_ptr[b].qs + half * 512 + 256 + sb * 64;
// 4 rows * 16 elements per scale
// 4 reads of 16 bytes each
constexpr int reads_per_sb = 4;
int8x16_t q8_l[reads_per_sb];
int8x16_t q8_h[reads_per_sb];
for (int k = 0; k < reads_per_sb; k++) {
q8_l[k] = vld1q_s8(q8_base_l + 16 * k);
q8_h[k] = vld1q_s8(q8_base_h + 16 * k);
}
const int ql_off_base = sb * QK_K / 2;
const int qh_off_base = ql_off_base & 255;
uint8x16_t q6_ql_0123[reads_per_sb];
uint8x16_t q6_ql_4567[reads_per_sb];
uint8x16_t q6_qh_0123[reads_per_sb];
uint8x16_t q6_qh_4567[reads_per_sb];
for (int k = 0; k < reads_per_sb; k++) {
q6_ql_0123[k] = vld1q_u8(ql_base + ql_off_base + k * 32);
q6_ql_4567[k] = vld1q_u8(ql_base + ql_off_base + k * 32 + 16);
q6_qh_0123[k] = vld1q_u8(qh_base + qh_off_base + k * 32);
q6_qh_4567[k] = vld1q_u8(qh_base + qh_off_base + k * 32 + 16);
}
if (sb > 1) {
for (int k = 0; k < reads_per_sb; k++) {
q6_qh_0123[k] = vshrq_n_u8(q6_qh_0123[k], 2);
q6_qh_4567[k] = vshrq_n_u8(q6_qh_4567[k], 2);
}
}
for (int k = 0; k < reads_per_sb; k++) {
// q = (ql | qh) - 32
const uint8x16_t hbit_lo_0123 = vandq_u8(q6_qh_0123[k], mask_lo);
const uint8x16_t hbit_hi_0123 = vandq_u8(q6_qh_0123[k], mask_hi);
const uint8x16_t hbit_lo_4567 = vandq_u8(q6_qh_4567[k], mask_lo);
const uint8x16_t hbit_hi_4567 = vandq_u8(q6_qh_4567[k], mask_hi);
const int8x16_t q6_0123_lo = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_0123[k], m4b), hbit_lo_0123, 4)), m32s);
const int8x16_t q6_0123_hi = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_0123[k], 4), hbit_hi_0123)), m32s);
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q6_0123_lo, q8_l[k], 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q6_0123_lo, q8_l[k], 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q6_0123_lo, q8_l[k], 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q6_0123_lo, q8_l[k], 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q6_0123_hi, q8_h[k], 0); // 64..67 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q6_0123_hi, q8_h[k], 1); // 64..67 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q6_0123_hi, q8_h[k], 2); // 64..67 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q6_0123_hi, q8_h[k], 3); // 64..67 r3 c0123
const int8x16_t q6_4567_lo = vsubq_s8(
vreinterpretq_s8_u8(vsliq_n_u8(vandq_u8(q6_ql_4567[k], m4b), hbit_lo_4567, 4)), m32s);
const int8x16_t q6_4567_hi = vsubq_s8(
vreinterpretq_s8_u8(vorrq_u8(vshrq_n_u8(q6_ql_4567[k], 4), hbit_hi_4567)), m32s);
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q6_4567_lo, q8_l[k], 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q6_4567_lo, q8_l[k], 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q6_4567_lo, q8_l[k], 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q6_4567_lo, q8_l[k], 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q6_4567_hi, q8_h[k], 0); // 64..67 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q6_4567_hi, q8_h[k], 1); // 64..67 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q6_4567_hi, q8_h[k], 2); // 64..67 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q6_4567_hi, q8_h[k], 3); // 64..67 r3 c4567
}
// Scale and bias
const int scale_idx_l = half * 8 + sb;
const int scale_idx_h = half * 8 + sb + 4;
for (int g = 0; g < col_groups; g++) {
const int16x4_t scales_l16 = vld1_s16(q6_scales + scale_idx_l * 8 + g * 4);
const int16x4_t scales_h16 = vld1_s16(q6_scales + scale_idx_h * 8 + g * 4);
const int32x4_t scale_vec_l = vmovl_s16(scales_l16);
const int32x4_t scale_vec_h = vmovl_s16(scales_h16);
const int acc_offset = g * q8_k_blocklen;
for (int row = 0; row < q8_k_blocklen; row++) {
const int idx = row * 2 + g;
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_lo[acc_offset + row], scale_vec_l);
acc_s32[idx] = vmlaq_s32(acc_s32[idx], acc_hi[acc_offset + row], scale_vec_h);
}
}
}
}
// Finally we apply the superblock scales
for (int row = 0; row < q8_k_blocklen; row++) {
const int idx0 = 2 * row;
const int idx1 = 2 * row + 1;
const int32x4_t acc_0123 = acc_s32[idx0];
const int32x4_t acc_4567 = acc_s32[idx1];
acc_f32[idx0] = vmlaq_f32(acc_f32[idx0], vcvtq_f32_s32(acc_0123), sbd_scale_0123[row]);
acc_f32[idx1] = vmlaq_f32(acc_f32[idx1], vcvtq_f32_s32(acc_4567), sbd_scale_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q6_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q6_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,

View File

@ -59,11 +59,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
GGML_ASSERT(nb00 == sizeof(src0_t));
const auto [ir0, ir1] = get_thread_range(params, src0);
const bool is_src1_contiguous = (nb10 == sizeof(src1_t));
if (!is_src1_contiguous) { // broadcast not implemented yet for non-contiguous
GGML_ASSERT(ggml_are_same_shape(src0, src1));
}
const bool is_src1_contiguous_rows = ggml_is_contiguous_rows(src1);
#ifdef GGML_USE_ACCELERATE
vDSP_fn_t vDSP_op = nullptr;
@ -94,7 +90,7 @@ static void apply_binary_op(const ggml_compute_params * params, ggml_tensor * ds
const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01);
const src1_t * src1_ptr = (const src1_t *) ((const char *) src1->data + i13*nb13 + i12*nb12 + i11*nb11);
if (is_src1_contiguous) {
if (is_src1_contiguous_rows) {
// src1 is broadcastable across src0 and dst in i1, i2, i3
const int64_t nr0 = ne00 / ne10;

View File

@ -2096,10 +2096,14 @@ static void ggml_compute_forward_gelu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2113,10 +2117,14 @@ static void ggml_compute_forward_gelu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2135,10 +2143,14 @@ static void ggml_compute_forward_gelu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2152,10 +2164,14 @@ static void ggml_compute_forward_gelu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2276,10 +2292,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2293,10 +2313,14 @@ static void ggml_compute_forward_gelu_erf_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2315,10 +2339,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2332,10 +2360,14 @@ static void ggml_compute_forward_gelu_erf_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_erf_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2379,10 +2411,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2396,10 +2432,14 @@ static void ggml_compute_forward_gelu_quick_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2418,10 +2458,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2435,10 +2479,14 @@ static void ggml_compute_forward_gelu_quick_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_gelu_quick_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2482,10 +2530,14 @@ static void ggml_compute_forward_silu_f32(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2499,10 +2551,14 @@ static void ggml_compute_forward_silu_f32(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f32(nc,
(float *) ((char *) dst->data + i1*( dst->nb[1])),
(float *) ((char *) src0->data + i1*(src0->nb[1])));
(float *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(float *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -2521,10 +2577,14 @@ static void ggml_compute_forward_silu_f16(
const ggml_tensor * src0 = dst->src[0];
assert(ggml_is_contiguous_1(src0));
assert(ggml_is_contiguous_1(dst));
assert(ggml_is_contiguous_rows(src0));
assert(ggml_are_same_shape(src0, dst));
GGML_TENSOR_LOCALS(int64_t, ne0, src0, ne)
GGML_TENSOR_LOCALS(size_t, nb0, src0, nb)
GGML_TENSOR_LOCALS(int64_t, ne, dst, ne)
GGML_TENSOR_LOCALS(size_t, nb, dst, nb)
const int ith = params->ith;
const int nth = params->nth;
@ -2538,10 +2598,14 @@ static void ggml_compute_forward_silu_f16(
const int ir0 = dr*ith;
const int ir1 = MIN(ir0 + dr, nr);
for (int i1 = ir0; i1 < ir1; i1++) {
for (int ir = ir0; ir < ir1; ++ir) {
const int i3 = ir/(ne02*ne01);
const int i2 = (ir - i3*ne02*ne01)/ne01;
const int i1 = (ir - i3*ne02*ne01 - i2*ne01);
ggml_vec_silu_f16(nc,
(ggml_fp16_t *) ((char *) dst->data + i1*( dst->nb[1])),
(ggml_fp16_t *) ((char *) src0->data + i1*(src0->nb[1])));
(ggml_fp16_t *) ((char *) dst->data + i3*nb3 + i2*nb2 + i1*nb1),
(ggml_fp16_t *) ((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01));
#ifndef NDEBUG
for (int k = 0; k < nc; k++) {
@ -7629,8 +7693,7 @@ static void ggml_compute_forward_pad_f32(
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(src0->nb[0] == sizeof(float));
GGML_ASSERT( dst->nb[0] == sizeof(float));
assert(dst->nb[0] == sizeof(float));
const int ith = params->ith;
const int nth = params->nth;

View File

@ -256,6 +256,200 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTR
ggml_quantize_mat_q8_K_4x8(x, vy, n_per_row);
}
template <int M, int N>
static void ggml_gemv_q6_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
const int blocks_per_half = 64 / blocklen;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0f;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
for (int j = 0; j < ncols_interleaved; j++) {
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / blocklen;
const int qh_pos_l = qh_idx_l % blocklen;
const int qh_offset_l = qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / blocklen;
const int qh_pos_h = qh_idx_h % blocklen;
const int qh_offset_h = qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t a_l = a_ptr[l].qs[base_l + i];
const int8_t a_h = a_ptr[l].qs[base_h + i];
sumi_l += q_l * a_l;
sumi_h += q_h * a_h;
}
sumf[j] +=
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
}
template <int M, int N>
static void ggml_gemm_q6_K_NxM_q8_K_generic_impl(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
constexpr int blocklen = M;
constexpr int ncols_interleaved = N;
const int qk = QK_K;
const int nb = n / qk;
const int blocks_per_half = 64 / blocklen;
const int q8_half_stride = 512;
const int q8_low_high_step = 256;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
float sumf[4][8];
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0f;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
const int base_l = (k / blocks_per_half) * 128 + (k % blocks_per_half) * blocklen;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
const int q8_base = (k / blocks_per_half) * q8_half_stride + (k % blocks_per_half) * (blocklen * 4);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * ncols_interleaved + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * ncols_interleaved + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * ncols_interleaved * blocklen + j * blocklen + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / blocklen;
const int qh_pos_l = qh_idx_l % blocklen;
const int qh_offset_l =
qh_chunk_l * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / blocklen;
const int qh_pos_h = qh_idx_h % blocklen;
const int qh_offset_h =
qh_chunk_h * (blocklen * ncols_interleaved) + j * blocklen + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t q8_l = a_ptr[l].qs[q8_base + m * blocklen + i];
const int8_t q8_h = a_ptr[l].qs[q8_base + m * blocklen + i + q8_low_high_step];
sumi_l += q_l * q8_l;
sumi_h += q_h * q8_h;
}
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
}
extern "C" {
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -704,94 +898,12 @@ void ggml_gemv_q5_K_8x8_q8_K_generic(int n,
}
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemv_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
assert(n % qk == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0f;
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
// qh indexing with 8-byte interleaving (like q5_K)
const int qh_byte_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_byte_l / 8;
const int qh_pos_l = qh_byte_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_byte_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_byte_h / 8;
const int qh_pos_h = qh_byte_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t a_l = a_ptr[l].qs[base_l + i];
const int8_t a_h = a_ptr[l].qs[base_h + i];
sumi_l += q_l * a_l;
sumi_h += q_h * a_h;
}
sumf[j] +=
(sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j];
}
}
ggml_gemv_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -1485,109 +1597,12 @@ void ggml_gemm_q5_K_8x8_q8_K_generic(int n,
}
}
void ggml_gemm_q6_K_8x8_q8_K_generic(int n,
float * GGML_RESTRICT s,
size_t bs,
const void * GGML_RESTRICT vx,
const void * GGML_RESTRICT vy,
int nr,
int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 8;
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q6_K_NxM_q8_K_generic_impl<4, 8>(n, s, bs, vx, vy, nr, nc);
}
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(bs);
float sumf[4][8];
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q6_Kx8 * b_ptr = (const block_q6_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0f;
}
}
for (int l = 0; l < nb; l++) {
for (int k = 0; k < 16; k++) {
// k = 0.. 7 weights 0-63 low, 64-127 high
// k = 8..15 weights 128-191 low, 192-255 high
const int base_l = (k / 8) * 128 + (k % 8) * 8;
const int base_h = base_l + 64;
const int scale_idx_l = base_l / 16;
const int scale_idx_h = base_h / 16;
// Bit shift cycles 0,2,4,6 for each 32-value group within a 128-value half
const int qh_shift_l = ((base_l % 128) / 32) * 2;
const int qh_shift_h = ((base_h % 128) / 32) * 2;
// qh_half: offset to the correct 32-byte half (0 or 32)
const int qh_half_l = (base_l / 128) * 32;
const int qh_half_h = (base_h / 128) * 32;
// Activation base indices for q8_Kx4 interleaved format
// Layout: 128-value halves (k/8), then 8-value sub-blocks (k%8) with stride 32
const int q8_base = (k / 8) * 512 + (k % 8) * 32;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
// Interleaved scales
const int8_t scale_l = b_ptr[l].scales[scale_idx_l * 8 + j];
const int8_t scale_h = b_ptr[l].scales[scale_idx_h * 8 + j];
int sumi_l = 0;
int sumi_h = 0;
for (int i = 0; i < blocklen; i++) {
const int ql_pos = k * 64 + j * 8 + i;
const int l_4 = b_ptr[l].ql[ql_pos] & 0xF;
const int hi_4 = (b_ptr[l].ql[ql_pos] >> 4) & 0xF;
const int qh_idx_l = qh_half_l + ((base_l + i) % 32);
const int qh_chunk_l = qh_idx_l / 8;
const int qh_pos_l = qh_idx_l % 8;
const int qh_offset_l = qh_chunk_l * 64 + j * 8 + qh_pos_l;
const int hi_2_l = (b_ptr[l].qh[qh_offset_l] >> qh_shift_l) & 0x3;
const int qh_idx_h = qh_half_h + ((base_h + i) % 32);
const int qh_chunk_h = qh_idx_h / 8;
const int qh_pos_h = qh_idx_h % 8;
const int qh_offset_h = qh_chunk_h * 64 + j * 8 + qh_pos_h;
const int hi_2_h = (b_ptr[l].qh[qh_offset_h] >> qh_shift_h) & 0x3;
const int q_l = ((hi_2_l << 4) | l_4) - 32;
const int q_h = ((hi_2_h << 4) | hi_4) - 32;
const int8_t q8_l = a_ptr[l].qs[q8_base + m * 8 + i];
const int8_t q8_h = a_ptr[l].qs[q8_base + m * 8 + i + 256];
sumi_l += q_l * q8_l;
sumi_h += q_h * q8_h;
}
sumf[m][j] += (sumi_l * scale_l + sumi_h * scale_h) * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) *
a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j];
}
}
}
}
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
ggml_gemm_q6_K_NxM_q8_K_generic_impl<8, 8>(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
@ -2097,18 +2112,18 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
}
const int end_ls = QK_K * 4 / blck_size_interleave;
// Interleave Q6_K quants by taking 8 bytes at a time
// Interleave Q6_K quants by taking blck_size_interleave bytes at a time
for (int i = 0; i < end_ls; ++i) {
int src_id = i % n_blocks;
int src_offset = (i / n_blocks) * blck_size_interleave;
int dst_offset = i * blck_size_interleave;
uint64_t elem_ls;
memcpy(&elem_ls, &in[src_id].ql[src_offset], sizeof(uint64_t));
memcpy(&out.ql[dst_offset], &elem_ls, sizeof(uint64_t));
memcpy(&elem_ls, &in[src_id].ql[src_offset], blck_size_interleave);
memcpy(&out.ql[dst_offset], &elem_ls, blck_size_interleave);
}
// Interleave high bits using same 8-byte pattern as low bits
// Interleave high bits using same chunk size as low bits
const int end_hs = end_ls / 2;
for (int i = 0; i < end_hs; ++i) {
int src_id = i % n_blocks;
@ -2116,8 +2131,8 @@ static block_q6_Kx8 make_block_q6_Kx8(block_q6_K * in, unsigned int blck_size_in
int dst_offset = i * blck_size_interleave;
uint64_t elem_hs;
memcpy(&elem_hs, &in[src_id].qh[src_offset], sizeof(uint64_t));
memcpy(&out.qh[dst_offset], &elem_hs, sizeof(uint64_t));
memcpy(&elem_hs, &in[src_id].qh[src_offset], blck_size_interleave);
memcpy(&out.qh[dst_offset], &elem_hs, blck_size_interleave);
}
// The below logic is designed so as to unpack and rearrange scales in Q6_K
@ -2262,7 +2277,7 @@ static int repack_q5_K_to_q5_K_8_bl(struct ggml_tensor * t,
static int repack_q6_K_to_q6_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q6_K);
GGML_ASSERT(interleave_block == 8);
GGML_ASSERT(interleave_block == 4 || interleave_block == 8);
constexpr int nrows_interleaved = 8;
block_q6_Kx8 * dst = (block_q6_Kx8 *)t->data;
@ -2511,6 +2526,10 @@ template <> int repack<block_q5_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q5_K_to_q5_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q6_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q6_K_to_q6_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q6_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q6_K_to_q6_K_8_bl(t, 8, data, data_size);
}
@ -2575,6 +2594,10 @@ template <> void gemv<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemv_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -2634,6 +2657,10 @@ template <> void gemm<block_q5_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t
ggml_gemm_q5_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q6_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q6_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q6_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q6_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -3043,6 +3070,7 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q5_K, 8, 8, GGML_TYPE_Q8_K> q5_K_8x8_q8_K;
// instance for Q6_K
static const ggml::cpu::repack::tensor_traits<block_q6_K, 4, 8, GGML_TYPE_Q8_K> q6_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q6_K, 8, 8, GGML_TYPE_Q8_K> q6_K_8x8_q8_K;
// instance for Q2
@ -3107,6 +3135,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q6_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q6_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_IQ4_NL) {
if (ggml_cpu_has_avx2()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -112,6 +112,7 @@ void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -122,6 +123,7 @@ void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const vo
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -142,6 +144,7 @@ void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -152,6 +155,7 @@ void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs,
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q5_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q6_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

View File

@ -111,7 +111,7 @@ template <float (*op)(float), typename src0_t, typename dst_t>
static void apply_unary_op(const ggml_compute_params * params, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst));
GGML_ASSERT(ggml_is_contiguous_rows(src0) && ggml_is_contiguous_rows(dst) && ggml_are_same_shape(src0, dst));
GGML_TENSOR_UNARY_OP_LOCALS

View File

@ -64,7 +64,7 @@ if (CUDAToolkit_FOUND)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc2
GIT_TAG v3.2.0
GIT_SHALLOW TRUE
)

View File

@ -39,13 +39,16 @@ static __global__ void k_bin_bcast(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@ -72,11 +75,11 @@ static __global__ void k_bin_bcast(const src0_t * src0,
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@ -101,13 +104,16 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
/*const int s0,*/
const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s00,
const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s10,
const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
@ -135,11 +141,11 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0,
const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f;
float result = src0_row ? (float) src0_row[i0*s00] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) {
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10])));
result = (..., (result = bin_op(result, (float)src1s[i_src1 + i10*s10])));
} else {
result = bin_op(result, (float)src1[i_src1 + i10]);
result = bin_op(result, (float)src1[i_src1 + i10*s10]);
}
dst_row[i0] = (dst_t) result;
@ -179,7 +185,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
cnb[3] *= cne[3];
};
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) {
if (ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) {
for (int i = 0; i < 4; i++) {
if (nr[i] != 1) {
break;
@ -221,7 +227,7 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
size_t nb12 = cnb1[2];
size_t nb13 = cnb1[3];
size_t s0 = nb0 / sizeof(dst_t);
//size_t s0 = nb0 / sizeof(dst_t);
size_t s1 = nb1 / sizeof(dst_t);
size_t s2 = nb2 / sizeof(dst_t);
size_t s3 = nb3 / sizeof(dst_t);
@ -251,10 +257,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
GGML_ASSERT(nb12 % sizeof(src1_t) == 0);
GGML_ASSERT(nb13 % sizeof(src1_t) == 0);
GGML_ASSERT(s0 == 1);
GGML_ASSERT(s00 == 1);
GGML_ASSERT(s10 == 1);
const int block_size = 128;
int64_t hne0 = std::max(ne0 / 2LL, 1LL);
@ -284,31 +286,31 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
} else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/*s0,*/ s1, s2, s3,
s00 ,s01, s02, s03,
s10, s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
} else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
/* s0, */ s1, s2, s3,
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12, s13);
/*s0,*/ s1, s2, s3,
s00, s01, s02, s03,
s10, s11, s12, s13);
}
}
}

View File

@ -7,7 +7,8 @@
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __restrict__ y,
const int64_t ne00, const int64_t ne01, const int64_t ne02,
const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = 2 * (int64_t(blockDim.x)*blockIdx.x + threadIdx.x);
@ -16,23 +17,27 @@ static __global__ void dequantize_block(const void * __restrict__ vx, dst_t * __
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t ibx0 = i03*s03 + i02*s02 + i01*s01;
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t ib = ibx0 + i00/qk; // block index
const int64_t iqs = (i00%qk)/qr; // quant index
const int64_t iybs = i00 - i00%qk; // y block start index
const int64_t y_offset = qr == 1 ? 1 : qk/2;
const int64_t iy0 = ((i03*ne02 + i02)*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
// dequantize
float2 v;
dequantize_kernel(vx, ib, iqs, v);
const int64_t iy0 = (i0203*ne01 + i01)*ne00 + iybs + iqs;
y[iy0 + 0] = ggml_cuda_cast<dst_t>(v.x);
y[iy0 + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
template <bool need_check>
@ -485,9 +490,11 @@ template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static void dequantize_block_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + 2*CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / (2*CUDA_DEQUANTIZE_BLOCK_SIZE), ne01, (int)std::min(ne0203, (int64_t)65535));
dequantize_block<qk, qr, dequantize_kernel><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
@ -612,7 +619,8 @@ static void dequantize_row_mxfp4_cuda(const void * vx, dst_t * y, const int64_t
template <typename src_t, typename dst_t>
static __global__ void convert_unary(
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01, const int64_t ne02,
const void * __restrict__ vx, dst_t * __restrict__ y, const int64_t ne00, const int64_t ne01,
const int64_t ne0203, const uint3 ne02,
const int64_t s01, const int64_t s02, const int64_t s03) {
const int64_t i00 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
@ -621,23 +629,29 @@ static __global__ void convert_unary(
}
const int64_t i01 = blockIdx.y;
const int64_t i02 = blockIdx.z % ne02;
const int64_t i03 = blockIdx.z / ne02;
const src_t * x = (const src_t *) vx;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
for (int64_t i0203 = blockIdx.z; i0203 < ne0203; i0203 += gridDim.z) {
const uint2 dm = fast_div_modulo((uint32_t)i0203, ne02);
const int64_t i02 = dm.y;
const int64_t i03 = dm.x;
const int64_t ix = i03*s03 + i02*s02 + i01*s01 + i00;
const int64_t iy = (i0203*ne01 + i01)*ne00 + i00;
y[iy] = ggml_cuda_cast<dst_t>(x[ix]);
}
}
template <typename src_t, typename dst_t>
static void convert_unary_cuda(const void * vx, dst_t * y,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t s01, const int64_t s02, const int64_t s03, cudaStream_t stream) {
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, ne02*ne03);
const int64_t ne0203 = ne02*ne03;
const uint3 ne02_fdv = init_fastdiv_values(ne02);
const dim3 num_blocks((ne00 + CUDA_DEQUANTIZE_BLOCK_SIZE - 1) / CUDA_DEQUANTIZE_BLOCK_SIZE, ne01, (int)std::min(ne0203, (int64_t)65535));
convert_unary<src_t><<<num_blocks, CUDA_DEQUANTIZE_BLOCK_SIZE, 0, stream>>>
(vx, y, ne00, ne01, ne02, s01, s02, s03);
(vx, y, ne00, ne01, ne0203, ne02_fdv, s01, s02, s03);
}
template <typename src_t, typename dst_t>

View File

@ -63,11 +63,19 @@ static __global__ void flash_attn_ext_f16(
constexpr int frag_m = ncols == 8 ? 32 : 16;
constexpr int frag_n = ncols == 8 ? 8 : 16;
static_assert(D % frag_m == 0, "If ncols == 8 then D % frag_m must be 0.");
#if defined(GGML_USE_HIP)
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, _Float16, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, _Float16> frag_c_VKQ;
#else
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::row_major> frag_a_K;
typedef wmma::fragment<wmma::matrix_a, frag_m, frag_n, 16, half, wmma::col_major> frag_a_V;
typedef wmma::fragment<wmma::matrix_b, frag_m, frag_n, 16, half, wmma::col_major> frag_b;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, KQ_acc_t> frag_c_KQ;
typedef wmma::fragment<wmma::accumulator, frag_m, frag_n, 16, half> frag_c_VKQ;
#endif
constexpr int KQ_stride_tc = nwarps*frag_m; // Number of KQ rows calculated in parallel.
constexpr int VKQ_ratio = KQ_stride_tc/VKQ_stride; // Number of parallel VKQ accumulators needed to keep all warps busy.
@ -126,6 +134,19 @@ static __global__ void flash_attn_ext_f16(
__shared__ half VKQ[ncols*D_padded]; // Accumulator for final VKQ slice.
half2 * VKQ2 = (half2 *) VKQ;
#if defined(GGML_USE_HIP)
const _Float16 * K_h_f16 = reinterpret_cast<const _Float16 *>(K_h);
const _Float16 * V_h_f16 = reinterpret_cast<const _Float16 *>(V_h);
_Float16 * KQ_f16 = reinterpret_cast<_Float16 *>(KQ);
_Float16 * VKQ_f16 = reinterpret_cast<_Float16 *>(VKQ);
#else
const half * K_h_f16 = K_h;
const half * V_h_f16 = V_h;
half * KQ_f16 = KQ;
half * VKQ_f16 = VKQ;
#endif
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
@ -160,7 +181,7 @@ static __global__ void flash_attn_ext_f16(
for (int i0 = 0; i0 < D; i0 += 16) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ + j0*D_padded + i0, D_padded);
wmma::load_matrix_sync(Q_b[i0/16][j0/frag_n], KQ_f16 + j0*D_padded + i0, D_padded);
}
}
@ -180,7 +201,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 16) {
frag_a_K K_a;
wmma::load_matrix_sync(K_a, K_h + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
wmma::load_matrix_sync(K_a, K_h_f16 + int64_t(k_VKQ_0 + i_KQ_0 + frag_m*threadIdx.y)*stride_KV + k_KQ_0, stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(KQ_c[j], K_a, Q_b[k_KQ_0/16][j], KQ_c[j]);
@ -310,7 +331,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
wmma::load_matrix_sync(
KQ_b[k0/(VKQ_ratio*16)][j0/frag_n],
KQ + j0*(kqar*kqs_padded) + k,
KQ_f16 + j0*(kqar*kqs_padded) + k,
kqar*kqs_padded);
}
}
@ -328,7 +349,7 @@ static __global__ void flash_attn_ext_f16(
const int k = k0 + (threadIdx.y % VKQ_ratio)*16;
frag_a_V v_a;
wmma::load_matrix_sync(v_a, V_h + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
wmma::load_matrix_sync(v_a, V_h_f16 + int64_t(k_VKQ_0 + k)*stride_KV + i_VKQ_0 + frag_m*(threadIdx.y/VKQ_ratio), stride_KV);
#pragma unroll
for (int j = 0; j < ncols/frag_n; ++j) {
wmma::mma_sync(VKQ_c[i_VKQ_0/VKQ_stride][j], v_a, KQ_b[k0/(VKQ_ratio*16)][j], VKQ_c[i_VKQ_0/VKQ_stride][j]);
@ -344,7 +365,7 @@ static __global__ void flash_attn_ext_f16(
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += frag_n) {
wmma::store_matrix_sync(
KQ + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
KQ_f16 + offset_k + j0*D_padded + i_KQ_0 + frag_m*(threadIdx.y/VKQ_ratio),
VKQ_c[i_KQ_0/VKQ_stride][j0/frag_n],
D_padded, wmma::mem_col_major);
}

View File

@ -3640,11 +3640,13 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
n_fuse++;
if (n_fuse > 1) {
ggml_tensor fused_add_node;
memcpy(&fused_add_node, node, sizeof(ggml_tensor));
for (int j = 0; j < n_fuse - 1; ++j) {
node->src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
fused_add_node.src[j + 2] = cgraph->nodes[i + j + 1]->src[1];
}
cgraph->nodes[i + n_fuse - 1]->data = node->data;
ggml_cuda_op_fused_add(*cuda_ctx, node, n_fuse);
fused_add_node.data = cgraph->nodes[i + n_fuse - 1]->data;
ggml_cuda_op_fused_add(*cuda_ctx, &fused_add_node, n_fuse);
i += n_fuse - 1;
continue;
@ -4834,8 +4836,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
case GGML_OP_PAD:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_PAD:
return true;
case GGML_OP_UPSCALE:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:

View File

@ -7,7 +7,7 @@ __device__ __forceinline__ int64_t wrap_around(int64_t coord, int64_t size) {
return (coord + size) % size;
}
static __global__ void pad_f32(const float * src, float * dst,
static __global__ void pad_f32(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3,
@ -34,11 +34,8 @@ static __global__ void pad_f32(const float * src, float * dst,
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
dst[dst_idx] = src[src_idx];
} else {
@ -57,21 +54,21 @@ static __global__ void pad_f32(const float * src, float * dst,
const int64_t i02 = wrap_around(i2 - lp2, ne02);
const int64_t i03 = wrap_around(i3 - lp3, ne03);
const int64_t src_idx = i03 * (ne00 * ne01 * ne02) + i02 * (ne00 * ne01) + i01 * ne00 + i00;
const int64_t src_idx = i03 * s03 + i02 * s02 + i01 * s01 + i00 * s00;
dst[dst_idx] = src[src_idx];
}
}
static void pad_f32_cuda(const float * src, float * dst,
static void pad_f32_cuda(const float * src, size_t s00, size_t s01, size_t s02, size_t s03, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3,
const bool circular, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2 * ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst,
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, s00, s01, s02, s03, dst,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
ne0, ne1, ne2, ne3, circular);
}
@ -82,9 +79,10 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
float * dst_d = (float *) dst->data;
cudaStream_t stream = ctx.stream();
GGML_TENSOR_UNARY_OP_LOCALS;
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t *) (dst->op_params))[0];
const int32_t rp0 = ((const int32_t *) (dst->op_params))[1];
@ -96,7 +94,12 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const int32_t rp3 = ((const int32_t *) (dst->op_params))[7];
const int32_t circular = ((const int32_t *) (dst->op_params))[8];
pad_f32_cuda(src0_d, dst_d,
const size_t s00 = nb00 / ggml_type_size(src0->type);
const size_t s01 = nb01 / ggml_type_size(src0->type);
const size_t s02 = nb02 / ggml_type_size(src0->type);
const size_t s03 = nb03 / ggml_type_size(src0->type);
pad_f32_cuda(src0_d, s00, s01, s02, s03, dst_d,
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3],
(bool) circular, stream);

View File

@ -43,10 +43,15 @@ static __device__ void rope_yarn(
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_norm(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int32_t * pos,
const float freq_scale,
@ -59,23 +64,23 @@ static __global__ void rope_norm(const T * x,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
int idst = row_dst * ne0 + i0;
const int ix = channel_x*s2 + row_x*s1 + i0;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = i0 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 + i1 * s01 + i2 * s02 + i3 * s03;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0;
idst += row_indices[channel_x] * set_rows_stride;
idst = i1 * s1 + i0;
idst += row_indices[i2] * set_rows_stride;
}
const auto & store_coaelsced = [&](float x0, float x1) {
@ -92,7 +97,7 @@ static __global__ void rope_norm(const T * x,
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -110,10 +115,15 @@ static __global__ void rope_norm(const T * x,
template <bool forward, bool has_ff, typename T, typename D>
static __global__ void rope_neox(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int32_t * pos,
const float freq_scale,
@ -126,23 +136,24 @@ static __global__ void rope_neox(const T * x,
const int set_rows_stride) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
int idst = row_dst * ne0 + i0 / 2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
// Fusion optimization: ROPE + VIEW + SET_ROWS.
// The rope output is viewed as a 1D tensor and offset based on a row index in row_indices.
if (set_rows_stride != 0) {
idst = row_x * ne0 + i0 / 2;
idst += row_indices[channel_x] * set_rows_stride;
idst = i1 * s1 + i0 / 2;
idst += row_indices[i2] * set_rows_stride;
}
if (i0 >= n_dims) {
@ -152,7 +163,7 @@ static __global__ void rope_neox(const T * x,
return;
}
const float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
const float theta_base = pos[i2]*powf(theta_scale, i0/2.0f);
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -168,24 +179,42 @@ static __global__ void rope_neox(const T * x,
dst[idst + n_dims / 2] = ggml_cuda_cast<D>(x0 * sin_theta + x1 * cos_theta);
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_multi(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2,
const int n_dims, const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float theta_scale, const float * freq_factors, const mrope_sections sections, const bool is_imrope) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
template <bool forward, bool has_ff, typename T>
static __global__ void rope_multi(const T * x,
T * dst,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const mrope_sections sections,
const bool is_imrope) {
const int i0 = 2 * (blockDim.y * blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
if (i0 >= n_dims) {
dst[idst + i0/2 + 0] = x[ix + i0/2 + 0];
@ -200,27 +229,24 @@ static __global__ void rope_multi(
float theta_base = 0.0;
if (is_imrope) {
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
if (sector % 3 == 1 && sector < 3 * sections.v[1]) { // h
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 2 && sector < 3 * sections.v[2]) { // w
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
} else if (sector % 3 == 0 && sector < 3 * sections.v[0]) { // t
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} else {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
}
} else {
if (sector < sections.v[0]) {
theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f);
}
else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f);
theta_base = pos[i2] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2 + ne02 * 1] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w && sector < sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 2] * powf(theta_scale, i0 / 2.0f);
} else if (sector >= sec_w + sections.v[2]) {
theta_base = pos[i2 + ne02 * 3] * powf(theta_scale, i0 / 2.0f);
}
}
@ -238,37 +264,53 @@ static __global__ void rope_multi(
dst[idst + n_dims/2] = x0*sin_theta + x1*cos_theta;
}
template<bool forward, bool has_ff, typename T>
static __global__ void rope_vision(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims,
const int32_t * pos, const float freq_scale, const float ext_factor, const float attn_factor, const rope_corr_dims corr_dims,
const float theta_scale, const float * freq_factors, const mrope_sections sections) {
template <bool forward, bool has_ff, typename T>
static __global__ void rope_vision(const T * x,
T * dst,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int32_t * pos,
const float freq_scale,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float theta_scale,
const float * freq_factors,
const mrope_sections sections) {
const int i0 = 2*(blockDim.y*blockIdx.y + threadIdx.y);
if (i0 >= ne0) {
if (i0 >= ne00) {
return;
}
const int row_dst = blockDim.x*blockIdx.x + threadIdx.x;
const int row_x = row_dst % ne1;
const int channel_x = row_dst / ne1;
const uint32_t i3 = row_dst / (ne01 * ne02);
const uint32_t i2 = (row_dst - i3 * ne01 * ne02) / ne01;
const uint32_t i1 = row_dst - i3 * ne01 * ne02 - i2 * ne01;
const int idst = row_dst*ne0 + i0/2;
const int ix = channel_x*s2 + row_x*s1 + i0/2;
int idst = i0 / 2 + i1 * s1 + i2 * s2 + i3 * s3;
const int ix = i0 / 2 + i1 * s01 + i2 * s02 + i3 * s03;
const int sect_dims = sections.v[0] + sections.v[1];
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
const int sec_w = sections.v[1] + sections.v[0];
const int sector = (i0 / 2) % sect_dims;
float theta_base = 0.0;
if (sector < sections.v[0]) {
const int p = sector;
theta_base = pos[channel_x]*powf(theta_scale, p);
}
else if (sector >= sections.v[0] && sector < sec_w) {
theta_base = pos[i2] * powf(theta_scale, p);
} else if (sector >= sections.v[0] && sector < sec_w) {
const int p = sector - sections.v[0];
theta_base = pos[channel_x + ne2]*powf(theta_scale, p);
theta_base = pos[i2 + ne02] * powf(theta_scale, p);
}
const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f;
@ -288,10 +330,15 @@ static __global__ void rope_vision(
template <bool forward, typename T, typename D>
static void rope_norm_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int nr,
const int32_t * pos,
@ -304,31 +351,36 @@ static void rope_norm_cuda(const T * x,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_norm<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_norm<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
template <bool forward, typename T, typename D>
static void rope_neox_cuda(const T * x,
D * dst,
const int ne0,
const int ne1,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int nr,
const int32_t * pos,
@ -341,55 +393,92 @@ static void rope_neox_cuda(const T * x,
const int64_t * row_indices,
const int set_rows_stride,
cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_neox<forward, false><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
} else {
rope_neox<forward, true><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, s1, s2, n_dims, pos, freq_scale, ext_factor, attn_factor, corr_dims, theta_scale,
freq_factors, row_indices, set_rows_stride);
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, row_indices, set_rows_stride);
}
}
template<bool forward, typename T>
static void rope_multi_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, const bool is_imrope, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
template <bool forward, typename T>
static void rope_multi_cuda(const T * x,
T * dst,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const mrope_sections sections,
const bool is_imrope,
cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
const float theta_scale = powf(freq_base, -2.0f/n_dims);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
if (freq_factors == nullptr) {
rope_multi<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
} else {
rope_multi<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections, is_imrope);
}
}
template<bool forward, typename T>
static void rope_vision_cuda(
const T * x, T * dst, const int ne0, const int ne1, const int ne2, const int s1, const int s2, const int n_dims, const int nr,
const int32_t * pos, const float freq_scale, const float freq_base, const float ext_factor, const float attn_factor,
const rope_corr_dims corr_dims, const float * freq_factors, const mrope_sections sections, cudaStream_t stream) {
GGML_ASSERT(ne0 % 2 == 0);
template <bool forward, typename T>
static void rope_vision_cuda(const T * x,
T * dst,
const int ne00,
const int ne01,
const int ne02,
const int s01,
const int s02,
const int s03,
const int s1,
const int s2,
const int s3,
const int n_dims,
const int nr,
const int32_t * pos,
const float freq_scale,
const float freq_base,
const float ext_factor,
const float attn_factor,
const rope_corr_dims corr_dims,
const float * freq_factors,
const mrope_sections sections,
cudaStream_t stream) {
GGML_ASSERT(ne00 % 2 == 0);
const dim3 block_dims(1, CUDA_ROPE_BLOCK_SIZE, 1);
const int n_blocks_x = (ne0 + 2*CUDA_ROPE_BLOCK_SIZE - 1) / (2*CUDA_ROPE_BLOCK_SIZE);
const int n_blocks_x = (ne00 + 2 * CUDA_ROPE_BLOCK_SIZE - 1) / (2 * CUDA_ROPE_BLOCK_SIZE);
const dim3 block_nums(nr, n_blocks_x, 1);
// break down (head_dim, heads, seq) into (CUDA_ROPE_BLOCK_SIZE, x, heads * seq)
// where x ~= ceil(head_dim / CUDA_ROPE_BLOCK_SIZE);
@ -398,11 +487,11 @@ static void rope_vision_cuda(
if (freq_factors == nullptr) {
rope_vision<forward, false, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
} else {
rope_vision<forward, true, T><<<block_nums, block_dims, 0, stream>>>(
x, dst, ne0, ne1, ne2, s1, s2, n_dims, pos, freq_scale, ext_factor,
x, dst, ne00, ne01, ne02, s01, s02, s03, s1, s2, s3, n_dims, pos, freq_scale, ext_factor,
attn_factor, corr_dims, theta_scale, freq_factors, sections);
}
}
@ -445,6 +534,11 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
const size_t s01 = src0->nb[1] / ggml_type_size(src0->type);
const size_t s02 = src0->nb[2] / ggml_type_size(src0->type);
const size_t s03 = src0->nb[3] / ggml_type_size(src0->type);
const size_t s1 = dst->nb[1] / ggml_type_size(dst->type);
const size_t s2 = dst->nb[2] / ggml_type_size(dst->type);
const size_t s3 = dst->nb[3] / ggml_type_size(dst->type);
//const int n_past = ((int32_t *) dst->op_params)[0];
const int n_dims = ((int32_t *) dst->op_params)[1];
@ -495,57 +589,63 @@ void ggml_cuda_op_rope_impl(ggml_backend_cuda_context & ctx,
// compute
if (is_neox) {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_neox_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_mrope && !is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_multi_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
rope_multi_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, is_imrope, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_multi_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, is_imrope, stream);
rope_multi_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, is_imrope, stream);
} else {
GGML_ABORT("fatal error");
}
} else if (is_vision) {
if (src0->type == GGML_TYPE_F32) {
rope_vision_cuda<forward>(
(const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
rope_vision_cuda<forward>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, stream);
} else if (src0->type == GGML_TYPE_F16) {
rope_vision_cuda<forward>(
(const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, n_dims, nr, pos, freq_scale,
freq_base, ext_factor, attn_factor, corr_dims, freq_factors, sections, stream);
rope_vision_cuda<forward>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02, s03, s1,
s2, s3, n_dims, nr, pos, freq_scale, freq_base, ext_factor, attn_factor,
corr_dims, freq_factors, sections, stream);
} else {
GGML_ABORT("fatal error");
}
} else {
if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F32) {
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, float, float>((const float *) src0_d, (float *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F32 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims,
nr, pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, float, half>((const float *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else if (src0->type == GGML_TYPE_F16 && dst_type == GGML_TYPE_F16) {
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, s01, s02, n_dims, nr,
pos, freq_scale, freq_base, ext_factor, attn_factor, corr_dims,
freq_factors, row_indices, set_rows_stride, stream);
rope_norm_cuda<forward, half, half>((const half *) src0_d, (half *) dst_d, ne00, ne01, ne02, s01, s02,
s03, s1, s2, s3, n_dims, nr, pos, freq_scale, freq_base,
ext_factor, attn_factor, corr_dims, freq_factors, row_indices,
set_rows_stride, stream);
} else {
GGML_ABORT("fatal error");
}

View File

@ -1935,11 +1935,6 @@ static bool ggml_hexagon_supported_binary(const struct ggml_hexagon_session * se
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
@ -1991,6 +1986,25 @@ static bool ggml_hexagon_supported_unary(const struct ggml_hexagon_session * ses
return true;
}
static bool ggml_hexagon_supported_sum_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * dst = op;
if (!hex_supported_src0_type(src0->type)) {
return false;
}
if (!hex_supported_dst_type(dst->type)) {
return false;
}
// TODO: add support for non-contigiuos tensors
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(dst)) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_activations(const struct ggml_hexagon_session * sess,
const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
@ -2111,6 +2125,26 @@ static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session *
return true;
}
static bool ggml_hexagon_supported_argsort(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * dst = op; // indices
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (dst->type != GGML_TYPE_I32) {
return false;
}
if (src0->ne[0] > (16*1024)) {
// reject tensors with huge rows for now
return false;
}
return true;
}
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0];
@ -2278,6 +2312,9 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
case GGML_OP_SUB:
req->op = HTP_OP_SUB;
break;
case GGML_OP_DIV:
req->op = HTP_OP_DIV;
break;
default:
GGML_ABORT("ggml-hex: binary : unsupported op: %d\n", t->op);
break;
@ -2316,6 +2353,17 @@ static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer *
return n_bufs;
}
static inline size_t init_argsort_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_ARGSORT;
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) {
@ -2370,6 +2418,16 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
supported = true;
break;
case GGML_OP_SQR:
req->op = HTP_OP_SQR;
supported = true;
break;
case GGML_OP_SQRT:
req->op = HTP_OP_SQRT;
supported = true;
break;
case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU;
@ -2387,6 +2445,9 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_SWIGLU_OAI) {
req->op = HTP_OP_GLU_SWIGLU_OAI;
supported = true;
} else if (ggml_get_glu_op(t) == GGML_GLU_OP_GEGLU) {
req->op = HTP_OP_GLU_GEGLU;
supported = true;
}
break;
@ -2411,6 +2472,17 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
return n_bufs;
}
static inline size_t init_sum_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_SUM_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_ROPE;
@ -2519,6 +2591,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
break;
case GGML_OP_ADD_ID:
@ -2528,6 +2601,13 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break;
case GGML_OP_SUM_ROWS:
ggml_hexagon_dispatch_op<init_sum_rows_req>(sess, node, flags);
break;
case GGML_OP_UNARY:
if ((ggml_get_unary_op(node) == GGML_UNARY_OP_SILU) ||
(ggml_get_unary_op(node) == GGML_UNARY_OP_GELU)) {
@ -2536,7 +2616,8 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
break;
case GGML_OP_GLU:
if ((ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI)) {
(ggml_get_glu_op(node) == GGML_GLU_OP_SWIGLU_OAI) ||
(ggml_get_glu_op(node) == GGML_GLU_OP_GEGLU)) {
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
}
break;
@ -2564,6 +2645,10 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_cpy_req>(sess, node, flags);
break;
case GGML_OP_ARGSORT:
ggml_hexagon_dispatch_op<init_argsort_req>(sess, node, flags);
break;
default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
}
@ -2916,6 +3001,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_DIV:
supp = ggml_hexagon_supported_binary(sess, op);
break;
@ -2928,6 +3014,15 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SQR:
case GGML_OP_SQRT:
supp = ggml_hexagon_supported_unary(sess, op);
break;
case GGML_OP_SUM_ROWS:
supp = ggml_hexagon_supported_sum_rows(sess, op);
break;
case GGML_OP_SOFT_MAX:
supp = ggml_hexagon_supported_softmax(sess, op);
break;
@ -2943,7 +3038,7 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
case GGML_OP_GLU:
{
const auto glu_op = ggml_get_glu_op(op);
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI)) {
if ((glu_op == GGML_GLU_OP_SWIGLU) || (glu_op == GGML_GLU_OP_SWIGLU_OAI) || (glu_op == GGML_GLU_OP_GEGLU)) {
supp = ggml_hexagon_supported_activations(sess, op);
}
break;
@ -2968,6 +3063,10 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_cpy(sess, op);
break;
case GGML_OP_ARGSORT:
supp = ggml_hexagon_supported_argsort(sess, op);
break;
default:
break;
}

View File

@ -6,6 +6,7 @@ include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include_directories(
${HEXAGON_SDK_ROOT}/incs
${HEXAGON_SDK_ROOT}/incs/stddef
${CMAKE_CURRENT_SOURCE_DIR}/../../../include
${CMAKE_CURRENT_SOURCE_DIR}/../..
${CMAKE_CURRENT_SOURCE_DIR}/..
${CMAKE_CURRENT_SOURCE_DIR}
@ -21,6 +22,7 @@ add_library(${HTP_LIB} SHARED
matmul-ops.c
binary-ops.c
unary-ops.c
sum-rows-ops.c
softmax-ops.c
act-ops.c
rope-ops.c
@ -28,6 +30,7 @@ add_library(${HTP_LIB} SHARED
set-rows-ops.c
get-rows-ops.c
cpy-ops.c
argsort-ops.c
)
target_compile_definitions(${HTP_LIB} PRIVATE

View File

@ -410,7 +410,7 @@ static void unary_gelu_f32_per_thread(const struct htp_tensor * src0,
// gelu = x * sigmoid(1.702 * x) // current implementation
hvx_mul_scalar_f32((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (float) 1.702, ne0);
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -516,7 +516,7 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
// silu = x * sigmoid(x)
hvx_sigmoid_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, ne0);
hvx_mul_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_aaa((uint8_t *) dst_spad_ptr, (const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr, ne0);
}
dma_queue_push_vtcm_to_ddr(dma_queue,
@ -541,6 +541,143 @@ static void unary_silu_f32_per_thread(const struct htp_tensor * src0,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static const float GELU_COEF_A = 0.044715f;
static const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
static void glu_geglu_f32_per_thread(const struct htp_tensor * src0,
const struct htp_tensor * src1,
struct htp_tensor * dst,
const int32_t * op_params,
struct htp_spad * src0_spad,
struct htp_spad * src1_spad,
struct htp_spad * dst_spad,
uint32_t nth,
uint32_t ith,
uint32_t src0_nrows_per_thread,
dma_queue * dma_queue) {
htp_act_preamble3;
size_t src0_row_size = nb01;
size_t src1_row_size = nb11;
size_t dst_row_size = nb1;
uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count();
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return;
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) {
const int32_t swapped = op_params[1];
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
}
const size_t src0_row_size_aligned = hex_round_up(src0_row_size, VLEN);
const size_t src1_row_size_aligned = hex_round_up(src1_row_size, VLEN);
const size_t dst_row_size_aligned = hex_round_up(dst_row_size, VLEN);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread);
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread);
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread);
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"geglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transation for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const uint8_t * src0_spad_ptr = (const uint8_t *)(src0_spad + ib * (src0_row_size_aligned / sizeof(float)));
const uint8_t * src1_spad_ptr = (const uint8_t *)(src1_spad + ib * (src1_row_size_aligned / sizeof(float)));
uint8_t * dst_spad_ptr = (uint8_t *)(dst_spad + ib * (dst_row_size_aligned / sizeof(float)));
// geglu tanh implementation
// geglu(x, g) = gelu(x) * g
// gelu(x) = 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)))
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, src0_spad_ptr, nc); // res = x*x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, GELU_COEF_A, nc); // res = res * GELU_COEF_A
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, SQRT_2_OVER_PI, nc); // res = result * SQRT_2_OVER_PI
hvx_tanh_f32_aa((uint8_t *) dst_spad_ptr, (const uint8_t *) dst_spad_ptr, nc); // res = tanh(res)
hvx_add_scalar_f32_aa(dst_spad_ptr, (const uint8_t*)dst_spad_ptr, 1.0f, nc); // res = res + 1.0f
hvx_mul_f32_aaa(dst_spad_ptr, src0_spad_ptr, (const uint8_t *)dst_spad_ptr, nc); // res = res * x
hvx_mul_scalar_f32_aa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, 0.5f, nc); // res = res + 0.5f
hvx_mul_f32_aaa(dst_spad_ptr, (const uint8_t *)dst_spad_ptr, src1_spad_ptr, nc); // res = res * g
}
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size,
dst_row_size_aligned, block_size);
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
}
}
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "geglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
}
static void unary_silu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_f32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
@ -559,6 +696,12 @@ static void glu_swiglu_oai_f32(unsigned int n, unsigned int i, void * data) {
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static void glu_geglu_f32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_geglu_f32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]);
}
static int execute_op_activations_f32(struct htp_ops_context * octx) {
int err = HTP_STATUS_OK;
@ -593,6 +736,11 @@ static int execute_op_activations_f32(struct htp_ops_context * octx) {
act_op_func = unary_gelu_f32;
op_type = "gelu-f32";
break;
case HTP_OP_GLU_GEGLU:
act_op_func = glu_geglu_f32;
op_type = "geglu-f32";
break;
default:
FARF(ERROR, "Unsupported activations Op %u\n", octx->op);
return HTP_STATUS_NO_SUPPORT;

View File

@ -0,0 +1,281 @@
#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <HAP_farf.h>
#include <HAP_perf.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "ggml.h"
#include "hvx-utils.h"
#include "hex-dma.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif
struct htp_argsort_context {
struct htp_ops_context * octx;
uint32_t nrows_per_thread;
};
static inline bool all_greater_f32(HVX_Vector x, HVX_Vector y)
{
const HVX_Vector one = Q6_V_vsplat_R(1);
const HVX_Vector zero = Q6_V_vzero();
HVX_VectorPred pred = Q6_Q_vcmp_gt_VsfVsf(x, y);
HVX_Vector matches = Q6_V_vmux_QVV(pred, one, zero);
HVX_Vector sum = hvx_vec_reduce_sum_i32(matches);
return hvx_vec_get_i32(sum) == 32;
}
// Sorts values and mirrors swaps to indices.
static void quicksort_values_indices_asc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i
while (i <= j) {
// Check if we have at least one full vector
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(pivot_vec, vals_vec)) {
// If all elements are < pivot, we can skip this whole block
i += 32;
continue;
}
}
// Scalar fallback / cleanup
if (values[i] < pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j
while (i <= j) {
if (j - 32 >= i) {
// Load 32 elements ending at j.
// Since we want `values[j] > pivot`, let's load from j-31 to j.
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(vals_vec, pivot_vec)) {
j -= 32;
continue;
}
}
if (values[j] > pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_asc(values, indices, left, j);
if (i < right) quicksort_values_indices_asc(values, indices, i, right);
}
static void quicksort_values_indices_desc(float * values, int32_t * indices, int left, int right) {
if (left >= right) return;
int pivot_idx = (left + right) / 2;
float pivot = values[pivot_idx];
int i = left;
int j = right;
HVX_Vector pivot_vec = hvx_vec_splat_f32(pivot);
while (i <= j) {
// Vectorized scan for i (values[i] > pivot)
while (i <= j) {
if (i + 32 <= j) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + i);
if (all_greater_f32(vals_vec, pivot_vec)) {
i += 32;
continue;
}
}
if (values[i] > pivot) {
i++;
} else {
break;
}
}
// Vectorized scan for j (values[j] < pivot)
while (i <= j) {
if (j - 32 >= i) {
HVX_Vector vals_vec = *(HVX_UVector *)(values + j - 31);
if (all_greater_f32(pivot_vec, vals_vec)) {
j -= 32;
continue;
}
}
if (values[j] < pivot) {
j--;
} else {
break;
}
}
if (i <= j) {
float tmp_val = values[i];
values[i] = values[j];
values[j] = tmp_val;
int32_t tmp_idx = indices[i];
indices[i] = indices[j];
indices[j] = tmp_idx;
i++;
j--;
}
}
if (left < j) quicksort_values_indices_desc(values, indices, left, j);
if (i < right) quicksort_values_indices_desc(values, indices, i, right);
}
static void htp_argsort_f32(unsigned int n, unsigned int i, void * data) {
struct htp_argsort_context * actx = (struct htp_argsort_context *)data;
struct htp_ops_context * octx = actx->octx;
// Unpack context
const struct htp_tensor * src0 = &octx->src0;
const struct htp_tensor * dst = &octx->dst;
// Scratchpad memory
uint8_t * spad = octx->src0_spad.data + octx->src0_spad.size_per_thread * i;
// Dimensions
uint32_t ne00 = src0->ne[0];
uint32_t ne01 = src0->ne[1];
uint32_t ne02 = src0->ne[2];
uint32_t ne03 = src0->ne[3];
uint32_t nb01 = src0->nb[1];
//uint32_t nb02 = src0->nb[2];
//uint32_t nb03 = src0->nb[3];
uint32_t nb1 = dst->nb[1];
//uint32_t nb2 = dst->nb[2];
//uint32_t nb3 = dst->nb[3];
// Sort order
enum ggml_sort_order order = (enum ggml_sort_order) octx->op_params[0];
// Rows to process
uint32_t total_rows = ne01 * ne02 * ne03;
uint32_t rows_per_thread = actx->nrows_per_thread;
uint32_t start_row = rows_per_thread * i;
uint32_t end_row = MIN(start_row + rows_per_thread, total_rows);
// Scratchpad layout:
// We need space for one row of float data (values) and one row of int32 indices.
// values: ne00 * sizeof(float)
// indices: ne00 * sizeof(int32_t)
// Padded to 128 bytes.
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
float * values_buf = (float *) spad;
int32_t * indices_buf = (int32_t *) (spad + values_size);
for (uint32_t r = start_row; r < end_row; r++) {
uint32_t src_offset = r * nb01;
uint32_t dst_offset = r * nb1;
uint8_t * src_ptr = (uint8_t *) src0->data + src_offset;
uint8_t * dst_ptr = (uint8_t *) dst->data + dst_offset;
hex_l2fetch(src_ptr, ne00 * sizeof(float), ne00 * sizeof(float), 1);
hvx_copy_f32_au((uint8_t*)values_buf, src_ptr, ne00);
// Initialize indices
for (uint32_t j = 0; j < ne00; j++) {
indices_buf[j] = j;
}
// Sort values and mirror swaps to indices
if (order == GGML_SORT_ORDER_ASC) {
quicksort_values_indices_asc(values_buf, indices_buf, 0, ne00 - 1);
} else {
quicksort_values_indices_desc(values_buf, indices_buf, 0, ne00 - 1);
}
// Copy indices back to DDR
hvx_copy_f32_ua(dst_ptr, (const uint8_t *) indices_buf, ne00);
}
}
int op_argsort(struct htp_ops_context * octx) {
// Check supported types
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
// Allocate scratchpad
// We need 1 row of float + 1 row of int32 per thread.
uint32_t ne00 = octx->src0.ne[0];
size_t values_size = hex_round_up(ne00 * sizeof(float), 128);
size_t indices_size = hex_round_up(ne00 * sizeof(int32_t), 128);
size_t spad_per_thread = values_size + indices_size;
// Make sure we round up to 256 for alignment requirements
spad_per_thread = hex_round_up(spad_per_thread, 256);
size_t total_spad_size = spad_per_thread * octx->n_threads;
if (octx->ctx->vtcm_size < total_spad_size) {
FARF(ERROR, "argsort: VTCM size too small. Needed %zu, have %zu", total_spad_size, octx->ctx->vtcm_size);
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src0_spad.size = total_spad_size;
octx->src0_spad.size_per_thread = spad_per_thread;
FARF(HIGH, "argsort: %ux%ux%ux%u -> %ux%ux%ux%u (0x%x, 0x%x)",
octx->src0.ne[0], octx->src0.ne[1], octx->src0.ne[2], octx->src0.ne[3],
octx->dst.ne[0], octx->dst.ne[1], octx->dst.ne[2], octx->dst.ne[3],
octx->src0.data, octx->dst.data);
uint32_t total_rows = octx->src0.ne[1] * octx->src0.ne[2] * octx->src0.ne[3];
uint32_t n_jobs = MIN(total_rows, octx->n_threads);
struct htp_argsort_context actx;
actx.octx = octx;
actx.nrows_per_thread = (total_rows + n_jobs - 1) / n_jobs;
// Run jobs
worker_pool_run_func(octx->ctx->worker_pool, htp_argsort_f32, &actx, n_jobs);
return HTP_STATUS_OK;
}

File diff suppressed because it is too large Load Diff

View File

@ -42,32 +42,36 @@ enum htp_data_type {
HTP_TYPE_COUNT
};
// These values are manually translated over to HTP
// !!!! DO NOT ALTER THE ORDER OF THE FIRST FOUR ENUMS !!!!
// Do not reorder first 4 (used as an index)
enum htp_op {
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT = 4,
HTP_OP_MUL_MAT_ID = 5,
HTP_OP_RMS_NORM = 6,
HTP_OP_UNARY_SILU = 7,
HTP_OP_UNARY_GELU = 8,
HTP_OP_GLU_SWIGLU = 9,
HTP_OP_GLU_SWIGLU_OAI = 10,
HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
HTP_OP_CPY = 18,
HTP_OP_MUL = 0,
HTP_OP_ADD = 1,
HTP_OP_SUB = 2,
HTP_OP_DIV = 3,
HTP_OP_MUL_MAT,
HTP_OP_MUL_MAT_ID,
HTP_OP_RMS_NORM,
HTP_OP_UNARY_SILU,
HTP_OP_UNARY_GELU,
HTP_OP_GLU_SWIGLU,
HTP_OP_GLU_SWIGLU_OAI,
HTP_OP_GLU_GEGLU,
HTP_OP_SOFTMAX,
HTP_OP_ADD_ID,
HTP_OP_ROPE,
HTP_OP_FLASH_ATTN_EXT,
HTP_OP_SET_ROWS,
HTP_OP_GET_ROWS,
HTP_OP_SCALE,
HTP_OP_CPY,
HTP_OP_ARGSORT,
HTP_OP_SQR,
HTP_OP_SQRT,
HTP_OP_SUM_ROWS,
INVALID
};
static inline size_t htp_type_block_size(uint32_t t) {
static inline size_t htp_t_block_size(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return 1;
@ -103,22 +107,6 @@ static inline size_t htp_type_nbytes(uint32_t t) {
return 0;
}
static const char * htp_type_name(uint32_t t) {
switch (t) {
case HTP_TYPE_F32:
return "fp32";
case HTP_TYPE_F16:
return "fp16";
case HTP_TYPE_Q4_0:
return "q4_0";
case HTP_TYPE_Q8_0:
return "q8_0";
case HTP_TYPE_MXFP4:
return "mxfp4";
}
return 0;
}
// Internal types
#define QK_Q4_0x4x2 256 // 4x Q4_0 blocks packed with next 4x Q4_0 blocks (size in bytes 128)
#define QK_Q8_0x4x2 256 // 4x Q8_0 blocks concat with next 4x Q8_0 blocks

View File

@ -64,25 +64,12 @@ struct htp_ops_context {
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
struct fastdiv_values cpy_div_ne01; // fastdiv values for ne01
struct fastdiv_values cpy_div_ne02; // fastdiv values for ne02
struct fastdiv_values cpy_div_ne03; // fastdiv values for ne03
struct fastdiv_values cpy_rshp_div_n0; // fastdiv values for ne00
struct fastdiv_values cpy_rshp_div_n1n0; // fastdiv values for ne00*ne01
struct fastdiv_values cpy_rshp_div_n2n1n0; // fastdiv values for ne00*ne01*ne02
uint32_t flags;
};
@ -90,6 +77,7 @@ int op_matmul(struct htp_ops_context * octx);
int op_matmul_id(struct htp_ops_context * octx);
int op_binary(struct htp_ops_context * octx);
int op_unary(struct htp_ops_context * octx);
int op_sum_rows(struct htp_ops_context * octx);
int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx);
@ -98,5 +86,6 @@ int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
int op_cpy(struct htp_ops_context * octx);
int op_argsort(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */

View File

@ -46,127 +46,76 @@
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
// ADD variants
// Generic macro to define alignment permutations for an op
#define DEFINE_HVX_BINARY_OP_VARIANTS(OP_NAME, OP_MACRO) \
static inline void OP_NAME##_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) dst % 128 == 0); \
hvx_arith_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a, OP_MACRO); \
} \
static inline void OP_NAME##_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src0 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
assert((uintptr_t) src1 % 128 == 0); \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void OP_NAME##_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) { \
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, OP_MACRO); \
} \
static inline void hvx_add_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_ADD);
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_add_f32, HVX_OP_ADD)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_sub_f32, HVX_OP_SUB)
DEFINE_HVX_BINARY_OP_VARIANTS(hvx_mul_f32, HVX_OP_MUL)
// Dispatcher logic
#define HVX_BINARY_DISPATCHER(OP_NAME) \
static inline void OP_NAME(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) { \
if (hex_is_aligned((void *) dst, 128)) { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aaa(dst, src0, src1, num_elems); \
else OP_NAME##_aau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_aua(dst, src0, src1, num_elems); \
else OP_NAME##_auu(dst, src0, src1, num_elems); \
} \
} else { \
if (hex_is_aligned((void *) src0, 128)) { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uaa(dst, src0, src1, num_elems); \
else OP_NAME##_uau(dst, src0, src1, num_elems); \
} else { \
if (hex_is_aligned((void *) src1, 128)) OP_NAME##_uua(dst, src0, src1, num_elems); \
else OP_NAME##_uuu(dst, src0, src1, num_elems); \
} \
} \
}
static inline void hvx_add_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_ADD);
}
static inline void hvx_add_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_ADD);
}
static inline void hvx_add_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_ADD);
}
// SUB variants
static inline void hvx_sub_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_SUB);
}
static inline void hvx_sub_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_SUB);
}
static inline void hvx_sub_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_SUB);
}
// MUL variants
static inline void hvx_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_au(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
hvx_arith_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a, HVX_OP_MUL);
}
static inline void hvx_mul_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((unsigned long) src0 % 128 == 0);
assert((unsigned long) src1 % 128 == 0);
hvx_arith_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u, HVX_OP_MUL);
}
static inline void hvx_mul_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_arith_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u, HVX_OP_MUL);
}
// Dispatchers
static inline void hvx_add_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_add_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_add_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_add_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_sub_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_sub_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_sub_f32_uu(dst, src0, src1, num_elems);
}
}
static inline void hvx_mul_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128) && hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_aa(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_au(dst, src0, src1, num_elems);
}
} else if (hex_is_aligned((void *) src0, 128) && hex_is_aligned((void *) src1, 128)) {
hvx_mul_f32_ua(dst, src0, src1, num_elems);
} else {
hvx_mul_f32_uu(dst, src0, src1, num_elems);
}
}
HVX_BINARY_DISPATCHER(hvx_add_f32)
HVX_BINARY_DISPATCHER(hvx_sub_f32)
HVX_BINARY_DISPATCHER(hvx_mul_f32)
// Mul-Mul Optimized
static inline void hvx_mul_mul_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint8_t * restrict src2, const uint32_t num_elems) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src0 % 128 == 0);
@ -443,6 +392,68 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
}
}
//
// Square
//
#define hvx_sqr_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t elem_size = sizeof(float); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = HVX_OP_MUL(vsrc[i], vsrc[i]); \
} \
if (nloe) { \
HVX_Vector v = HVX_OP_MUL(vsrc[i], vsrc[i]); \
vec_store((void *) &vdst[i], nloe * elem_size, v); \
} \
} while(0)
static inline void hvx_sqr_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqr_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqr_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqr_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqr_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqr_f32(uint8_t * restrict dst, const uint8_t * restrict src, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_aa(dst, src, num_elems);
} else {
hvx_sqr_f32_au(dst, src, num_elems);
}
} else {
if (hex_is_aligned((void *) src, 128)) {
hvx_sqr_f32_ua(dst, src, num_elems);
} else {
hvx_sqr_f32_uu(dst, src, num_elems);
}
}
}
#undef HVX_OP_ADD
#undef HVX_OP_SUB
#undef HVX_OP_MUL
@ -453,5 +464,7 @@ static inline void hvx_clamp_scalar_f32(uint8_t * restrict dst, const uint8_t *
#undef hvx_scalar_loop_body
#undef HVX_OP_MIN_SCALAR
#undef HVX_OP_CLAMP_SCALAR
#undef DEFINE_HVX_BINARY_OP_VARIANTS
#undef HVX_BINARY_DISPATCHER
#endif // HVX_ARITH_H

View File

@ -66,6 +66,12 @@ static inline float hvx_vec_get_f32(HVX_Vector v) {
return x;
}
static inline int32_t hvx_vec_get_i32(HVX_Vector v) {
int32_t __attribute__((aligned(128))) x;
hvx_vec_store_a(&x, 4, v);
return x;
}
static inline HVX_Vector hvx_vec_abs_f16(HVX_Vector v) {
// abs by clearing the fp16 sign bit
HVX_Vector mask = Q6_Vh_vsplat_R(0x7fff);

View File

@ -136,8 +136,6 @@ static inline void hvx_copy_f32_uu(uint8_t * restrict dst, const uint8_t * restr
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const HVX_Vector zero = Q6_V_vsplat_R(0); \
\
const uint32_t elem_size = sizeof(__fp16); \
const uint32_t epv = 128 / elem_size; \
const uint32_t nvec = n / epv; \

View File

@ -0,0 +1,116 @@
#ifndef HVX_DIV_H
#define HVX_DIV_H
#include <HAP_farf.h>
#include <math.h>
#include <string.h>
#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "hvx-base.h"
#include "hex-utils.h"
#include "hvx-inverse.h"
#include "hvx-arith.h"
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
#define hvx_div_f32_loop_body(dst_type, src0_type, src1_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src0_type * restrict vsrc0 = (src0_type *) src0; \
src1_type * restrict vsrc1 = (src1_type *) src1; \
\
const HVX_Vector nan_inf_mask = Q6_V_vsplat_R(0x7f800000); \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vdst[i] = res; \
} \
if (nloe) { \
HVX_Vector inv_src1 = hvx_vec_inverse_f32_guard(vsrc1[i], nan_inf_mask); \
HVX_Vector res = HVX_OP_MUL(vsrc0[i], inv_src1); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, res); \
} \
} while(0)
// 3-letter suffix variants
static inline void hvx_div_f32_aaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_aua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_div_f32_auu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) dst % 128 == 0);
hvx_div_f32_loop_body(HVX_Vector, HVX_UVector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_div_f32_uaa(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uau(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src0 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_Vector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uua(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
assert((uintptr_t) src1 % 128 == 0);
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_div_f32_uuu(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, uint32_t n) {
hvx_div_f32_loop_body(HVX_UVector, HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_div_f32(uint8_t * restrict dst, const uint8_t * restrict src0, const uint8_t * restrict src1, const uint32_t num_elems) {
if (hex_is_aligned((void *) dst, 128)) {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aaa(dst, src0, src1, num_elems);
else hvx_div_f32_aau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_aua(dst, src0, src1, num_elems);
else hvx_div_f32_auu(dst, src0, src1, num_elems);
}
} else {
if (hex_is_aligned((void *) src0, 128)) {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uaa(dst, src0, src1, num_elems);
else hvx_div_f32_uau(dst, src0, src1, num_elems);
} else {
if (hex_is_aligned((void *) src1, 128)) hvx_div_f32_uua(dst, src0, src1, num_elems);
else hvx_div_f32_uuu(dst, src0, src1, num_elems);
}
}
}
#undef HVX_OP_MUL
#endif // HVX_DIV_H

View File

@ -91,6 +91,27 @@ static inline HVX_Vector hvx_vec_tanh_f32(HVX_Vector x) {
} \
} while(0)
#define hvx_tanh_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t epv = 128 / sizeof(float); \
const uint32_t nvec = n / epv; \
const uint32_t nloe = n % epv; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
vdst[i] = hvx_vec_tanh_f32(vsrc[i]); \
} \
if (nloe) { \
HVX_Vector tmp = hvx_vec_tanh_f32(vsrc[i]); \
vec_store((void *) &vdst[i], nloe * sizeof(float), tmp); \
} \
} while(0)
static inline void hvx_sigmoid_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
@ -111,4 +132,10 @@ static inline void hvx_sigmoid_f32_uu(uint8_t * restrict dst, const uint8_t * re
hvx_sigmoid_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_tanh_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_tanh_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
#endif /* HVX_SIGMOID_H */

View File

@ -12,11 +12,17 @@
#define RSQRT_ONE_HALF 0x3f000000 // 0.5
#define RSQRT_THREE_HALVES 0x3fc00000 // 1.5
#if __HVX_ARCH__ < 79
#define HVX_OP_MUL(a, b) Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(a, b))
#else
#define HVX_OP_MUL(a, b) Q6_Vsf_vmpy_VsfVsf(a, b)
#endif
static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
//Algorithm :
// x2 = input*0.5
// y = * (long *) &input
// y = 0x5f3759df - (y>>2)
// y = 0x5f3759df - (y>>1)
// y = y*(threehalfs - x2*y*y)
HVX_Vector rsqrtconst = Q6_V_vsplat_R(RSQRT_CONST);
@ -57,4 +63,64 @@ static inline HVX_Vector hvx_vec_rsqrt_f32(HVX_Vector in_vec) {
return Q6_Vsf_equals_Vqf32(temp);
}
// Compute sqrt(x) as x*inv_sqrt(x)
#define hvx_sqrt_f32_loop_body(dst_type, src_type, vec_store) \
do { \
dst_type * restrict vdst = (dst_type *) dst; \
src_type * restrict vsrc = (src_type *) src; \
\
const uint32_t nvec = n / VLEN_FP32; \
const uint32_t nloe = n % VLEN_FP32; \
\
uint32_t i = 0; \
\
_Pragma("unroll(4)") \
for (; i < nvec; i++) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vdst[i] = sqrt_res; \
} \
if (nloe) { \
HVX_Vector inv_sqrt = hvx_vec_rsqrt_f32(vsrc[i]); \
HVX_Vector sqrt_res = HVX_OP_MUL(inv_sqrt, vsrc[i]); \
vec_store((void *) &vdst[i], nloe * SIZEOF_FP32, sqrt_res); \
} \
} while(0)
static inline void hvx_sqrt_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_Vector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) dst % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_Vector, HVX_UVector, hvx_vec_store_a);
}
static inline void hvx_sqrt_f32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
assert((unsigned long) src % 128 == 0);
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_Vector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
hvx_sqrt_f32_loop_body(HVX_UVector, HVX_UVector, hvx_vec_store_u);
}
static inline void hvx_sqrt_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int num_elems) {
if ((unsigned long) dst % 128 == 0) {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_aa(dst, src, num_elems);
} else {
hvx_sqrt_f32_au(dst, src, num_elems);
}
} else {
if ((unsigned long) src % 128 == 0) {
hvx_sqrt_f32_ua(dst, src, num_elems);
} else {
hvx_sqrt_f32_uu(dst, src, num_elems);
}
}
}
#endif /* HVX_SQRT_H */

View File

@ -12,6 +12,7 @@
#include "hvx-sigmoid.h"
#include "hvx-sqrt.h"
#include "hvx-arith.h"
#include "hvx-div.h"
#include "hvx-base.h"
#endif /* HVX_UTILS_H */

View File

@ -189,7 +189,7 @@ static int vtcm_release_callback(unsigned int rctx, void * state) {
// otherwise we'll release it once we're done with the current Op.
if (ctx->vtcm_inuse) {
ctx->vtcm_needs_release = false;
ctx->vtcm_needs_release = true;
return 0;
}
@ -440,6 +440,45 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_argsort_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_argsort(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_cpy_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
@ -679,6 +718,45 @@ static void proc_unary_req(struct htp_context * ctx, struct htp_general_req * re
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_sum_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
// We had written to the output buffer, we'd also need to flush it
rsp_bufs[0].fd = bufs[1].fd;
rsp_bufs[0].ptr = bufs[1].ptr;
rsp_bufs[0].offset = bufs[1].offset;
rsp_bufs[0].size = bufs[1].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.dst.data = (uint32_t) bufs[1].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_sum_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_activations_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
@ -951,6 +1029,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_MUL:
case HTP_OP_ADD:
case HTP_OP_SUB:
case HTP_OP_DIV:
if (n_bufs != 3) {
FARF(ERROR, "Bad binary-req buffer list");
continue;
@ -968,6 +1047,25 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SQR:
case HTP_OP_SQRT:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_unary_req(ctx, &req, bufs);
break;
case HTP_OP_SUM_ROWS:
if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list");
continue;
}
proc_sum_rows_req(ctx, &req, bufs);
break;
case HTP_OP_UNARY_SILU:
case HTP_OP_UNARY_GELU:
if (n_bufs != 2) {
@ -980,6 +1078,7 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
case HTP_OP_GLU_SWIGLU:
case HTP_OP_GLU_SWIGLU_OAI:
case HTP_OP_SOFTMAX:
case HTP_OP_GLU_GEGLU:
if ((n_bufs != 2) && (n_bufs != 3)) {
FARF(ERROR, "Bad act-req buffer list");
continue;
@ -1035,6 +1134,14 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_cpy_req(ctx, &req, bufs);
break;
case HTP_OP_ARGSORT:
if (n_bufs != 2) {
FARF(ERROR, "Bad argsort-req buffer list");
continue;
}
proc_argsort_req(ctx, &req, bufs);
break;
default:
FARF(ERROR, "Unknown Op %u", req.op);
break;

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,115 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#include <HAP_farf.h>
#include <HAP_perf.h>
#include <string.h>
#include <math.h>
#include "hex-dma.h"
#include "hvx-utils.h"
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#define sum_rows_preamble \
struct htp_tensor *src0 = &octx->src0;\
struct htp_tensor *dst = &octx->dst; \
\
const uint32_t ne00 = src0->ne[0]; \
const uint32_t ne01 = src0->ne[1]; \
const uint32_t ne02 = src0->ne[2]; \
const uint32_t ne03 = src0->ne[3]; \
\
const uint32_t nb00 = src0->nb[0]; \
const uint32_t nb01 = src0->nb[1]; \
const uint32_t nb02 = src0->nb[2]; \
const uint32_t nb03 = src0->nb[3]; \
\
const uint32_t ne0 = dst->ne[0]; \
const uint32_t ne1 = dst->ne[1]; \
const uint32_t ne2 = dst->ne[2]; \
const uint32_t ne3 = dst->ne[3]; \
\
const uint32_t nb0 = dst->nb[0]; \
const uint32_t nb1 = dst->nb[1]; \
const uint32_t nb2 = dst->nb[2]; \
const uint32_t nb3 = dst->nb[3]; \
static int sum_rows_thread_f32(struct htp_ops_context * octx, const int nth, const int ith) {
sum_rows_preamble;
const uint32_t src0_nrows_per_thread = octx->src0_nrows_per_thread;
const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith;
const uint32_t src0_end_row = MIN(src0_start_row + src0_nrows_per_thread, src0_nrows);
// no work for this thread
if (src0_start_row >= src0_end_row) {
return HTP_STATUS_OK;
}
int opt_path = 0;
if ((0 == hex_is_aligned((void *) src0->data, VLEN)) && !(nb01 & (VLEN - 1))) {
opt_path = 1;
}
const uint8_t * restrict data_src = (const uint8_t *) src0->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data;
const float * restrict src_th = (float *) (data_src + (src0_start_row * src0_row_size));
float * restrict dst_th = (float *) (data_dst + (src0_start_row * dst_row_size));
for (uint32_t ir = 0; ir < src0_nrows_per_thread; ir++) {
const float * restrict src_local = src_th + (ir * ne00);
if (ir + 1 < src0_nrows_per_thread) {
hex_l2fetch(src_local + ne00, src0_row_size, src0_row_size, 1);
}
if (1 == opt_path) {
dst_th[ir] = hvx_reduce_sum_f32_a((const uint8_t *) src_local, ne00);
} else {
dst_th[ir] = hvx_reduce_sum_f32((const uint8_t *) src_local, ne00);
}
}
return HTP_STATUS_OK;
}
static void sum_rows_work_f32(unsigned int n, unsigned int i, void *data) {
sum_rows_thread_f32((struct htp_ops_context *) data, n, i);
}
int op_sum_rows(struct htp_ops_context * octx) {
sum_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
const int n_threads = octx->n_threads;
const uint32_t src0_nrows = ne01 * ne02 * ne03;
uint32_t n_jobs = MIN(n_threads, src0_nrows);
octx->src0_nrows_per_thread = (src0_nrows + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, sum_rows_work_f32, octx, n_jobs);
return HTP_STATUS_OK;
}

View File

@ -132,6 +132,56 @@ static void rms_norm_htp_f32(const float * restrict src,
}
}
static void sqr_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqr_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqr_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void sqrt_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
hex_l2fetch(src_local + row_elems, row_size, row_size, 1);
}
if (1 == opt_path) {
hvx_sqrt_f32_aa((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
} else {
hvx_sqrt_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems);
}
}
}
static void unary_job_f32_per_thread(const struct htp_tensor * src,
struct htp_tensor * dst,
uint8_t * spad,
@ -181,6 +231,12 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQR:
sqr_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
case HTP_OP_SQRT:
sqrt_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default:
break;
@ -218,6 +274,14 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
case HTP_OP_SQR:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqr-f32";
break;
case HTP_OP_SQRT:
unary_op_func = unary_job_dispatcher_f32;
op_type = "sqrt-f32";
break;
default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op);

View File

@ -264,15 +264,25 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
case GGML_OP_NORM:
case GGML_OP_RMS_NORM:
case GGML_OP_GROUP_NORM:
case GGML_OP_L2_NORM:
case GGML_OP_SUM_ROWS:
case GGML_OP_SSM_CONV:
case GGML_OP_SSM_SCAN:
case GGML_OP_CLAMP:
case GGML_OP_TRI:
case GGML_OP_DIAG:
case GGML_OP_MUL:
case GGML_OP_ADD:
case GGML_OP_DIV:
case GGML_OP_GLU:
case GGML_OP_SCALE:
case GGML_OP_UNARY:
case GGML_OP_GET_ROWS:
case GGML_OP_CPY:
case GGML_OP_SET_ROWS:
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_CONT:
case GGML_OP_REPEAT:
return true;
default:
return ggml_op_is_empty(op);
@ -312,7 +322,7 @@ static std::vector<int> ggml_metal_graph_optimize_reorder(const std::vector<node
h_add(mrs1, node0);
// that many nodes forward to search for a concurrent node
constexpr int N_FORWARD = 8;
constexpr int N_FORWARD = 64;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) {

View File

@ -394,7 +394,7 @@ bool ggml_metal_cpy_tensor_async(ggml_metal_t ctx_src, ggml_metal_t ctx_dst, con
[encoder endEncoding];
ggml_metal_event_t ev_cpy = ggml_metal_get_ev_cpy(ctx_src);
ggml_metal_event_record(ctx_src, ev_cpy);
ggml_metal_event_encode_signal(ev_cpy, cmd_buf);
[cmd_buf commit];

View File

@ -212,61 +212,69 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_repeat(ggml_meta
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_unary(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
char base[256];
char name[256];
const int64_t n = ggml_nelements(op);
int op_num = -1;
const char * op_str = "undefined";
switch (op->op) {
case GGML_OP_SCALE: op_str = "scale"; break;
case GGML_OP_FILL: op_str = "fill"; break;
case GGML_OP_CLAMP: op_str = "clamp"; break;
case GGML_OP_SQR: op_str = "sqr"; break;
case GGML_OP_SQRT: op_str = "sqrt"; break;
case GGML_OP_SIN: op_str = "sin"; break;
case GGML_OP_COS: op_str = "cos"; break;
case GGML_OP_LOG: op_str = "log"; break;
case GGML_OP_LEAKY_RELU: op_str = "leaky_relu"; break;
case GGML_OP_SCALE: op_num = OP_UNARY_NUM_SCALE; break;
case GGML_OP_FILL: op_num = OP_UNARY_NUM_FILL; break;
case GGML_OP_CLAMP: op_num = OP_UNARY_NUM_CLAMP; break;
case GGML_OP_SQR: op_num = OP_UNARY_NUM_SQR; break;
case GGML_OP_SQRT: op_num = OP_UNARY_NUM_SQRT; break;
case GGML_OP_SIN: op_num = OP_UNARY_NUM_SIN; break;
case GGML_OP_COS: op_num = OP_UNARY_NUM_COS; break;
case GGML_OP_LOG: op_num = OP_UNARY_NUM_LOG; break;
case GGML_OP_LEAKY_RELU: op_num = OP_UNARY_NUM_LEAKY_RELU; break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH: op_str = "tanh"; break;
case GGML_UNARY_OP_RELU: op_str = "relu"; break;
case GGML_UNARY_OP_SIGMOID: op_str = "sigmoid"; break;
case GGML_UNARY_OP_GELU: op_str = "gelu"; break;
case GGML_UNARY_OP_GELU_ERF: op_str = "gelu_erf"; break;
case GGML_UNARY_OP_GELU_QUICK: op_str = "gelu_quick"; break;
case GGML_UNARY_OP_SILU: op_str = "silu"; break;
case GGML_UNARY_OP_ELU: op_str = "elu"; break;
case GGML_UNARY_OP_NEG: op_str = "neg"; break;
case GGML_UNARY_OP_ABS: op_str = "abs"; break;
case GGML_UNARY_OP_SGN: op_str = "sgn"; break;
case GGML_UNARY_OP_STEP: op_str = "step"; break;
case GGML_UNARY_OP_HARDSWISH: op_str = "hardswish"; break;
case GGML_UNARY_OP_HARDSIGMOID: op_str = "hardsigmoid"; break;
case GGML_UNARY_OP_EXP: op_str = "exp"; break;
case GGML_UNARY_OP_SOFTPLUS: op_str = "softplus"; break;
case GGML_UNARY_OP_EXPM1: op_str = "expm1"; break;
case GGML_UNARY_OP_TANH: op_num = OP_UNARY_NUM_TANH; break;
case GGML_UNARY_OP_RELU: op_num = OP_UNARY_NUM_RELU; break;
case GGML_UNARY_OP_SIGMOID: op_num = OP_UNARY_NUM_SIGMOID; break;
case GGML_UNARY_OP_GELU: op_num = OP_UNARY_NUM_GELU; break;
case GGML_UNARY_OP_GELU_ERF: op_num = OP_UNARY_NUM_GELU_ERF; break;
case GGML_UNARY_OP_GELU_QUICK: op_num = OP_UNARY_NUM_GELU_QUICK; break;
case GGML_UNARY_OP_SILU: op_num = OP_UNARY_NUM_SILU; break;
case GGML_UNARY_OP_ELU: op_num = OP_UNARY_NUM_ELU; break;
case GGML_UNARY_OP_NEG: op_num = OP_UNARY_NUM_NEG; break;
case GGML_UNARY_OP_ABS: op_num = OP_UNARY_NUM_ABS; break;
case GGML_UNARY_OP_SGN: op_num = OP_UNARY_NUM_SGN; break;
case GGML_UNARY_OP_STEP: op_num = OP_UNARY_NUM_STEP; break;
case GGML_UNARY_OP_HARDSWISH: op_num = OP_UNARY_NUM_HARDSWISH; break;
case GGML_UNARY_OP_HARDSIGMOID: op_num = OP_UNARY_NUM_HARDSIGMOID; break;
case GGML_UNARY_OP_EXP: op_num = OP_UNARY_NUM_EXP; break;
case GGML_UNARY_OP_SOFTPLUS: op_num = OP_UNARY_NUM_SOFTPLUS; break;
case GGML_UNARY_OP_EXPM1: op_num = OP_UNARY_NUM_EXPM1; break;
default: GGML_ABORT("fatal error");
} break;
default: GGML_ABORT("fatal error");
};
const char * suffix = "";
if (n % 4 == 0) {
suffix = "_4";
}
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_%s_%s%s", op_str, ggml_type_name(op->src[0]->type), suffix);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const bool is_cnt = ggml_is_contiguous(op->src[0]) && ggml_nelements(op) < 32768;
snprintf(base, 256, "kernel_unary_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_cnt=%d", base, op_num, is_cnt);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_UNARY + 0);
ggml_metal_cv_set_bool (cv, is_cnt, FC_UNARY + 1);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.cnt = is_cnt;
return res;
}
@ -320,31 +328,46 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum(ggml_metal_l
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
char base[256];
char name[256];
const char * op_str = "undefined";
int op_num = -1;
switch (op->op) {
case GGML_OP_SUM_ROWS:
op_str = "sum_rows"; break;
case GGML_OP_MEAN:
op_str = "mean"; break;
case GGML_OP_SUM_ROWS: op_num = OP_SUM_ROWS_NUM_SUM_ROWS; break;
case GGML_OP_MEAN: op_num = OP_SUM_ROWS_NUM_MEAN; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_%s_%s", op_str, ggml_type_name(op->src[0]->type));
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
snprintf(base, 256, "kernel_sum_rows_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d", base, op_num);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_SUM_ROWS + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32*sizeof(float);
if (is_c4) {
res.smem *= 4;
}
res.c4 = is_c4;
return res;
}
@ -1392,34 +1415,78 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_v
GGML_UNUSED(op);
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
ggml_metal_library_t lib,
ggml_op op,
int32_t n_fuse,
bool row) {
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
char base[256];
char name[256];
const char * op_str = "undefined";
switch (op) {
case GGML_OP_ADD: op_str = "add"; break;
case GGML_OP_SUB: op_str = "sub"; break;
case GGML_OP_MUL: op_str = "mul"; break;
case GGML_OP_DIV: op_str = "div"; break;
int op_num = -1;
switch (op->op) {
case GGML_OP_ADD: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};
if (row) {
snprintf(base, 256, "kernel_%s_row_c4_fuse_%d", op_str, n_fuse);
} else {
snprintf(base, 256, "kernel_%s_fuse_%d", op_str, n_fuse);
}
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t1_str = ggml_type_name(op->src[1]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(name, 256, "%s", base);
const bool is_c4 = (op->src[0]->ne[0] % 4 == 0) && (op->src[1]->ne[0] % 4 == 0);
const bool is_rb = ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]) && (ggml_nrows(op->src[1]) == 1) && ggml_nelements(op) < 65536;
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s%s", t0_str, t1_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s_op=%d_nf=%d_rb=%d", base, op_num, n_fuse, is_rb);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, n_fuse, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, is_rb, FC_BIN + 2);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.c4 = is_c4;
res.cnt = is_rb;
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one(ggml_metal_library_t lib, ggml_op op) {
char base[256];
char name[256];
int op_num = -1;
switch (op) {
case GGML_OP_ADD: op_num = 0; break;
case GGML_OP_SUB: op_num = 1; break;
case GGML_OP_MUL: op_num = 2; break;
case GGML_OP_DIV: op_num = 3; break;
default: GGML_ABORT("fatal error");
};
snprintf(base, 256, "kernel_bin_fuse_%s_%s_%s", "f32", "f32", "f32");
snprintf(name, 256, "%s_op=%d_nf=%d", base, op_num, 1);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, op_num, FC_BIN + 0);
ggml_metal_cv_set_int16(cv, 1, FC_BIN + 1);
ggml_metal_cv_set_bool (cv, false, FC_BIN + 2);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
return res;
@ -1428,13 +1495,15 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin(
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_L2_NORM);
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
char base[256];
char name[256];
snprintf(base, 256, "kernel_l2_norm_f32");
const bool is_c4 = op->src[0]->ne[0] % 4 == 0;
const char * t0_str = ggml_type_name(op->src[0]->type);
const char * t_str = ggml_type_name(op->type);
snprintf(base, 256, "kernel_l2_norm_%s_%s%s", t0_str, t_str, is_c4 ? "_4" : "");
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
@ -1442,6 +1511,7 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm(ggml_met
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
res.c4 = is_c4;
res.smem = 32*sizeof(float);
return res;

View File

@ -53,6 +53,9 @@ struct ggml_metal_pipeline_with_params {
int nr1;
size_t smem;
bool c4;
bool cnt;
};
int ggml_metal_pipeline_max_theads_per_threadgroup(struct ggml_metal_pipeline_with_params pipeline);
@ -134,7 +137,8 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_argsort_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_top_k_merge (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, enum ggml_op op, int32_t n_fuse, bool row);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse );
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_bin_one (ggml_metal_library_t lib, enum ggml_op op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_l2_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_group_norm (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_norm (ggml_metal_library_t lib, const struct ggml_tensor * op, int32_t n_fuse);

View File

@ -346,10 +346,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};
res.pipeline = ggml_metal_pipelines_get(lib->pipelines, name);
@ -362,10 +364,12 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline(ggml_meta
struct ggml_metal_pipeline_with_params ggml_metal_library_compile_pipeline(ggml_metal_library_t lib, const char * base, const char * name, ggml_metal_cv_t cv) {
struct ggml_metal_pipeline_with_params res = {
/*.pipeline =*/ nil,
/*.nsg =*/ 0,
/*.nr0 =*/ 0,
/*.nr1 =*/ 0,
/*.nsg =*/ 0,
/*.smem =*/ 0,
/*.c4 =*/ false,
/*.cnt =*/ false,
};
[lib->lock lock];
@ -1007,6 +1011,15 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
}
switch (op->op) {
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CLAMP:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_TANH:
@ -1026,7 +1039,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_EXPM1:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
default:
return false;
}
@ -1054,11 +1067,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous_rows(op->src[0]) && ggml_is_contiguous_rows(op->src[1]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_ACC:
case GGML_OP_REPEAT:
case GGML_OP_SCALE:
case GGML_OP_FILL:
case GGML_OP_CONV_TRANSPOSE_1D:
return true;
case GGML_OP_CONV_TRANSPOSE_2D:
@ -1066,14 +1077,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F16 || op->src[0]->type == GGML_TYPE_F32) &&
op->src[1]->type == GGML_TYPE_F32 &&
op->type == GGML_TYPE_F32;
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
return has_simdgroup_reduction && ggml_is_contiguous(op->src[0]);
case GGML_OP_TRI:
@ -1083,9 +1086,8 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MEAN:
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
@ -1157,6 +1159,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_MUL_MAT:
case GGML_OP_MUL_MAT_ID:
return has_simdgroup_reduction;
case GGML_OP_SET:
case GGML_OP_CPY:
case GGML_OP_DUP:
case GGML_OP_CONT:

View File

@ -80,6 +80,9 @@
#define FC_SSM_CONV 900
#define FC_SOLVE_TRI 1000
#define FC_COUNT_EQUAL 1100
#define FC_UNARY 1200
#define FC_BIN 1300
#define FC_SUM_ROWS 1400
// op-specific constants
#define OP_FLASH_ATTN_EXT_NQPSG 8
@ -88,6 +91,37 @@
#define OP_FLASH_ATTN_EXT_VEC_NQPSG 1
#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32
#define OP_UNARY_NUM_SCALE 10
#define OP_UNARY_NUM_FILL 11
#define OP_UNARY_NUM_CLAMP 12
#define OP_UNARY_NUM_SQR 13
#define OP_UNARY_NUM_SQRT 14
#define OP_UNARY_NUM_SIN 15
#define OP_UNARY_NUM_COS 16
#define OP_UNARY_NUM_LOG 17
#define OP_UNARY_NUM_LEAKY_RELU 18
#define OP_UNARY_NUM_TANH 100
#define OP_UNARY_NUM_RELU 101
#define OP_UNARY_NUM_SIGMOID 102
#define OP_UNARY_NUM_GELU 103
#define OP_UNARY_NUM_GELU_ERF 104
#define OP_UNARY_NUM_GELU_QUICK 105
#define OP_UNARY_NUM_SILU 106
#define OP_UNARY_NUM_ELU 107
#define OP_UNARY_NUM_NEG 108
#define OP_UNARY_NUM_ABS 109
#define OP_UNARY_NUM_SGN 110
#define OP_UNARY_NUM_STEP 111
#define OP_UNARY_NUM_HARDSWISH 112
#define OP_UNARY_NUM_HARDSIGMOID 113
#define OP_UNARY_NUM_EXP 114
#define OP_UNARY_NUM_SOFTPLUS 115
#define OP_UNARY_NUM_EXPM1 116
#define OP_SUM_ROWS_NUM_SUM_ROWS 10
#define OP_SUM_ROWS_NUM_MEAN 11
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@ -123,6 +157,31 @@ typedef struct {
int32_t dim;
} ggml_metal_kargs_concat;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float slope;
float scale;
float bias;
float val;
float min;
float max;
} ggml_metal_kargs_unary;
typedef struct {
int32_t ne00;
int32_t ne01;
@ -180,20 +239,6 @@ typedef struct {
uint64_t nb3;
} ggml_metal_kargs_repeat;
typedef struct {
float scale;
float bias;
} ggml_metal_kargs_scale;
typedef struct {
float val;
} ggml_metal_kargs_fill;
typedef struct {
float min;
float max;
} ggml_metal_kargs_clamp;
typedef struct {
int64_t nk0;
int64_t ne00;
@ -497,8 +542,21 @@ typedef struct {
typedef struct {
int32_t ne00;
int32_t ne00_4;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne0;
int32_t ne1;
int32_t ne2;
int32_t ne3;
uint64_t nb0;
uint64_t nb1;
uint64_t nb2;
uint64_t nb3;
float eps;
} ggml_metal_kargs_l2_norm;
@ -880,10 +938,6 @@ typedef struct {
int max_period;
} ggml_metal_kargs_timestep_embedding;
typedef struct {
float slope;
} ggml_metal_kargs_leaky_relu;
typedef struct {
int32_t ne00;
int32_t ne01;

View File

@ -287,17 +287,9 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
n_fuse = ggml_metal_op_acc(ctx, idx);
} break;
case GGML_OP_SCALE:
{
n_fuse = ggml_metal_op_scale(ctx, idx);
} break;
case GGML_OP_FILL:
{
n_fuse = ggml_metal_op_fill(ctx, idx);
} break;
case GGML_OP_CLAMP:
{
n_fuse = ggml_metal_op_clamp(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
@ -426,10 +418,6 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_top_k(ctx, idx);
} break;
case GGML_OP_LEAKY_RELU:
{
n_fuse = ggml_metal_op_leaky_relu(ctx, idx);
} break;
case GGML_OP_TRI:
{
n_fuse = ggml_metal_op_tri(ctx, idx);
@ -438,6 +426,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{
n_fuse = ggml_metal_op_flash_attn_ext(ctx, idx);
} break;
case GGML_OP_SET:
{
n_fuse = ggml_metal_op_set(ctx, idx);
} break;
case GGML_OP_DUP:
case GGML_OP_CPY:
case GGML_OP_CONT:
@ -707,7 +699,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
/*.o1 =*/ { 0 },
};
auto pipeline = ggml_metal_library_get_pipeline_bin(lib, GGML_OP_ADD, 1, false);
auto pipeline = ggml_metal_library_get_pipeline_bin_one(lib, GGML_OP_ADD);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
@ -722,119 +714,6 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_scale(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float scale;
float bias;
memcpy(&scale, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&bias, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_scale args = {
/*.scale =*/ scale,
/*.bias =*/ bias,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_fill(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
const float val = ggml_get_op_params_f32(op, 0);
ggml_metal_kargs_fill args = {
/*.val =*/ val
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_clamp(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float min;
float max;
memcpy(&min, ((const int32_t *) op->op_params) + 0, sizeof(float));
memcpy(&max, ((const int32_t *) op->op_params) + 1, sizeof(float));
ggml_metal_kargs_clamp args = {
/*.min =*/ min,
/*.max =*/ max,
};
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@ -846,19 +725,79 @@ int ggml_metal_op_unary(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
int64_t n = ggml_nelements(op);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
if (n % 4 == 0) {
n /= 4;
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_unary args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.slope =*/ 0.0,
/*.scale =*/ 0.0,
/*.bias =*/ 0.0,
/*.val =*/ 0.0,
/*.min =*/ 0.0,
/*.max =*/ 0.0,
};
if (op->op == GGML_OP_LEAKY_RELU) {
args.slope = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_SCALE) {
args.scale = ggml_get_op_params_f32(op, 0);
args.bias = ggml_get_op_params_f32(op, 1);
}
if (op->op == GGML_OP_FILL) {
args.val = ggml_get_op_params_f32(op, 0);
}
if (op->op == GGML_OP_CLAMP) {
args.min = ggml_get_op_params_f32(op, 0);
args.max = ggml_get_op_params_f32(op, 1);
}
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 1);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
const int nth = MIN(args.ne00, nth_max);
const int nk0 = (args.ne00 + nth - 1)/nth;
ggml_metal_encoder_dispatch_threadgroups(enc, nk0*ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}
@ -969,6 +908,11 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
ggml_metal_kargs_sum_rows args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@ -990,21 +934,26 @@ int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) {
auto pipeline = ggml_metal_library_get_pipeline_sum_rows(lib, op);
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
while (nth < args.ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00);
nth = std::min(nth, (int) args.ne00);
const size_t smem = pipeline.smem;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
@ -1664,6 +1613,134 @@ int ggml_metal_op_solve_tri(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_set(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
const size_t pnb1 = ((const int32_t *) op->op_params)[0];
const size_t pnb2 = ((const int32_t *) op->op_params)[1];
const size_t pnb3 = ((const int32_t *) op->op_params)[2];
const size_t offs = ((const int32_t *) op->op_params)[3];
const bool inplace = (bool) ((const int32_t *) op->op_params)[4];
if (!inplace) {
// run a separete kernel to cpy src->dst
// not sure how to avoid this
// TODO: make a simpler cpy_bytes kernel
//const id<MTLComputePipelineState> pipeline = ctx->pipelines[GGML_METAL_PIPELINE_TYPE_CPY_F32_F32].obj;
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ ne00,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
};
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne00);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
ggml_metal_op_concurrency_reset(ctx);
}
auto pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[1]->type, op->type);
GGML_ASSERT(ne10 % ggml_blck_size(op->src[1]->type) == 0);
int64_t nk0 = ne10;
if (ggml_is_quantized(op->src[1]->type)) {
nk0 = ne10/16;
} else if (ggml_is_quantized(op->type)) {
nk0 = ne10/ggml_blck_size(op->type);
}
int nth = std::min<int>(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
// when rows are small, we can batch them together in a single threadgroup
int nrptg = 1;
// TODO: relax this constraint in the future
if (ggml_blck_size(op->src[1]->type) == 1 && ggml_blck_size(op->type) == 1) {
if (nth > nk0) {
nrptg = (nth + nk0 - 1)/nk0;
nth = nk0;
if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nrptg--;
}
}
}
nth = std::min<int>(nth, nk0);
ggml_metal_kargs_cpy args = {
/*.nk0 =*/ nk0,
/*.ne00 =*/ ne10,
/*.ne01 =*/ ne11,
/*.ne02 =*/ ne12,
/*.ne03 =*/ ne13,
/*.nb00 =*/ nb10,
/*.nb01 =*/ nb11,
/*.nb02 =*/ nb12,
/*.nb03 =*/ nb13,
/*.ne0 =*/ ne10,
/*.ne1 =*/ ne11,
/*.ne2 =*/ ne12,
/*.ne3 =*/ ne13,
/*.nb0 =*/ ggml_element_size(op),
/*.nb1 =*/ pnb1,
/*.nb2 =*/ pnb2,
/*.nb3 =*/ pnb3,
};
const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1;
bid_dst.offs += offs;
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src1, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne11 + nrptg - 1)/nrptg, ne12, ne13, nth, nrptg, 1);
return 1;
}
int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
@ -2895,8 +2972,6 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
GGML_ASSERT(ggml_is_contiguous_rows(op->src[1]));
bool bcast_row = false;
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
@ -2990,18 +3065,7 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
struct ggml_metal_pipeline_with_params pipeline;
if (ggml_nelements(op->src[1]) == ne10 && ggml_is_contiguous(op->src[1]) && ne00 % 4 == 0 && ne10 % 4 == 0) {
GGML_ASSERT(ggml_is_contiguous(op->src[0]));
// src1 is a row
GGML_ASSERT(ne11 == 1);
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, true);
bcast_row = true;
} else {
pipeline = ggml_metal_library_get_pipeline_bin(lib, op->op, n_fuse, false);
}
pipeline = ggml_metal_library_get_pipeline_bin(lib, op, n_fuse);
if (n_fuse > 1) {
bid_dst = ggml_metal_get_buffer_id(ctx->node(idx + n_fuse - 1));
@ -3015,20 +3079,28 @@ int ggml_metal_op_bin(ggml_metal_op_t ctx, int idx) {
}
}
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne10 = ne10/4;
args.ne0 = ne0/4;
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_src1, 2);
ggml_metal_encoder_set_buffer (enc, bid_dst, 3);
if (bcast_row) {
const int64_t n = ggml_nelements(op)/4;
if (pipeline.cnt) {
const int n = pipeline.c4 ? ggml_nelements(op)/4 : ggml_nelements(op);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
} else {
int nth = 32;
const int nth_max = MIN(256, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
while (16*nth < ne0 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
int nth = 1;
while (2*nth < args.ne0 && nth < nth_max) {
nth *= 2;
}
@ -3049,39 +3121,59 @@ int ggml_metal_op_l2_norm(ggml_metal_op_t ctx, int idx) {
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]);
ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op);
float eps;
memcpy(&eps, op->op_params, sizeof(float));
int nth = 32; // SIMD width
ggml_metal_kargs_l2_norm args = {
/*.ne00 =*/ ne00,
/*.ne00_4 =*/ ne00/4,
/*.nb01 =*/ nb01,
/*.eps =*/ eps,
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.ne3 =*/ ne3,
/*.nb0 =*/ nb0,
/*.nb1 =*/ nb1,
/*.nb2 =*/ nb2,
/*.nb3 =*/ nb3,
/*.eps =*/ eps,
};
auto pipeline = ggml_metal_library_get_pipeline_l2_norm(lib, op);
while (nth < ne00/4 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
if (pipeline.c4) {
args.ne00 = ne00/4;
args.ne0 = ne0/4;
}
int nth = 32; // SIMD width
while (nth < ne00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) {
nth *= 2;
}
nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
nth = std::min(nth, ne00/4);
const size_t smem = pipeline.smem;
const int64_t nrows = ggml_nrows(op->src[0]);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_set_buffer (enc, bid_src0, 1);
ggml_metal_encoder_set_buffer (enc, bid_dst, 2);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, nrows, 1, 1, nth, 1, 1);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
return 1;
}
@ -4089,42 +4181,6 @@ int ggml_metal_op_top_k(ggml_metal_op_t ctx, int idx) {
return 1;
}
int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
GGML_TENSOR_LOCALS(uint64_t, nb, op, nb);
float slope;
memcpy(&slope, op->op_params, sizeof(float));
ggml_metal_kargs_leaky_relu args = {
/*.slope =*/ slope
};
auto pipeline = ggml_metal_library_get_pipeline_unary(lib, op);
int64_t n = ggml_nelements(op);
if (n % 4 == 0) {
n /= 4;
}
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, 1, 1, 1);
return 1;
}
int ggml_metal_op_tri(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);

View File

@ -46,9 +46,6 @@ size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op);
int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_repeat (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_acc (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_fill (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx);
@ -62,6 +59,7 @@ int ggml_metal_op_ssm_conv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_ssm_scan (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_rwkv (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_solve_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_set (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_cpy (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_1d (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_pool_2d (ggml_metal_op_t ctx, int idx);
@ -86,7 +84,6 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_top_k (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);

File diff suppressed because it is too large Load Diff

View File

@ -85,6 +85,9 @@ set(GGML_OPENCL_KERNELS
mul_mv_q4_0_f32_8x_flat
mul_mv_q4_0_f32_1d_8x_flat
mul_mv_q4_0_f32_1d_16x_flat
mul_mv_q4_1_f32
mul_mv_q4_1_f32_flat
mul_mv_q4_k_f32
mul_mv_q6_k_f32
mul_mv_q6_k_f32_flat
mul_mv_q8_0_f32
@ -100,7 +103,10 @@ set(GGML_OPENCL_KERNELS
gemv_moe_mxfp4_f32
mul_mm_f32_f32_l4_lm
mul_mm_f16_f32_l4_lm
mul_mm_q4_0_f32_l4_lm
mul_mm_q4_1_f32_l4_lm
mul_mm_q8_0_f32_l4_lm
mul_mm_q6_k_f32_l4_lm
mul_mm_q8_0_f32_8x4
gemv_noshuffle_general_q8_0_f32
mul

View File

@ -525,6 +525,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mm_f16_f32_kq;
cl_kernel kernel_mul_mat_q4_0_f32, kernel_mul_mat_q4_0_f32_v;
cl_kernel kernel_convert_block_q4_0, kernel_restore_block_q4_0;
cl_kernel kernel_convert_block_q4_1, kernel_restore_block_q4_1;
cl_kernel kernel_convert_block_mxfp4, kernel_convert_block_mxfp4_trans, kernel_restore_block_mxfp4, kernel_restore_block_mxfp4_trans;
cl_kernel kernel_convert_block_q8_0, kernel_restore_block_q8_0, kernel_restore_block_q8_0_trans;
cl_kernel kernel_mul_mat_q4_0_f32_8x_flat;
@ -532,6 +533,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_restore_block_q4_0_noshuffle;
cl_kernel kernel_convert_block_q6_K, kernel_restore_block_q6_K;
cl_kernel kernel_mul_mat_q4_0_f32_1d_8x_flat, kernel_mul_mat_q4_0_f32_1d_16x_flat;
cl_kernel kernel_mul_mv_q4_1_f32;
cl_kernel kernel_mul_mv_q4_1_f32_flat;
cl_kernel kernel_mul_mv_q4_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32;
cl_kernel kernel_mul_mv_q6_K_f32_flat;
cl_kernel kernel_mul_mv_mxfp4_f32, kernel_mul_mv_mxfp4_f32_flat;
@ -563,7 +567,10 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_mul_mv_id_mxfp4_f32_flat;
cl_kernel kernel_mul_mm_f32_f32_l4_lm;
cl_kernel kernel_mul_mm_f16_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q4_1_f32_l4_lm;
cl_kernel kernel_mul_mm_q8_0_f32_l4_lm;
cl_kernel kernel_mul_mm_q6_k_f32_l4_lm;
std::vector<ProfilingInfo> profiling_info;
@ -886,6 +893,8 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
CL_CHECK((backend_ctx->kernel_restore_block_q4_0_noshuffle = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0_noshuffle", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_0 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_0", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_q4_1 = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_q4_1", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4 = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4", &err), err));
CL_CHECK((backend_ctx->kernel_convert_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_convert_block_mxfp4_trans", &err), err));
CL_CHECK((backend_ctx->kernel_restore_block_mxfp4_trans = clCreateKernel(backend_ctx->program_cvt, "kernel_restore_block_mxfp4_trans", &err), err));
@ -1117,6 +1126,57 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mv_q4_1_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q4_1_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q4_1_f32.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q4_1_f32_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q4_1_f32_flat.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q4_1_f32_flat.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q4_1_f32_flat = clCreateKernel(prog, "kernel_mul_mv_q4_1_f32_flat", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q4_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mv_q4_k_f32.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mv_q4_k_f32.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mv_q4_K_f32 = clCreateKernel(prog, "kernel_mul_mv_q4_K_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_q6_k_f32
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -1342,6 +1402,38 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_q4_0_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q4_0_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q4_0_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_0_f32_l4_lm", &err), err));
GGML_LOG_CONT(".");
}
// mul_mm_q4_1_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q4_1_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q4_1_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q4_1_f32_l4_lm", &err), err));
GGML_LOG_CONT(".");
}
// mul_mm_q8_0_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -1358,6 +1450,23 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// mul_mm_q6_k_f32_l4_lm
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mul_mm_q6_k_f32_l4_lm.cl.h"
};
#else
const std::string kernel_src = read_file("mul_mm_q6_k_f32_l4_lm.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm = clCreateKernel(prog, "kernel_mul_mm_q6_k_f32_l4_lm", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mm_f16_f32_kq_kqv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -2887,6 +2996,59 @@ struct ggml_tensor_extra_cl_q4_0 {
}
};
struct ggml_tensor_extra_cl_q4_1 {
// Quantized values.
cl_mem q = nullptr;
// Quantized values in image1d_buffer_t.
cl_mem q_img = nullptr;
// Scales.
cl_mem d = nullptr;
// Scales in image1d_buffer_t.
cl_mem d_img = nullptr;
// Min
cl_mem m = nullptr;
// Min in image1d_buffer_t.
cl_mem m_img = nullptr;
// Size of quantized values.
size_t size_q = 0;
// Size of scales.
size_t size_d = 0;
// Size of min values.
size_t size_m = 0;
~ggml_tensor_extra_cl_q4_1() {
reset();
}
void reset() {
// q and d are subbuffers into the bigger buffer allocated in ggml_backend_buffer.
// They must be properly released so that the original buffer can be
// properly released to avoid memory leak.
if (q != nullptr) {
CL_CHECK(clReleaseMemObject(q));
q = nullptr;
}
if (d != nullptr) {
CL_CHECK(clReleaseMemObject(d));
d = nullptr;
}
if (m != nullptr) {
CL_CHECK(clReleaseMemObject(m));
m = nullptr;
}
// Currently, q_img and d_img are only initialized when SMALL_ALLOC is
// enabled. They point to the images in ggml_backend_opencl_buffer_context.
// So, there is no need to release them here.
// TODO: initialize them for non SMALL_PATH path, or remove them.
q_img = nullptr;
d_img = nullptr;
m_img = nullptr;
size_q = 0;
size_d = 0;
size_m = 0;
}
};
struct ggml_tensor_extra_cl_mxfp4 {
// Quantized values.
cl_mem q = nullptr;
@ -3363,7 +3525,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return true;
} else if (op->src[0]->type == GGML_TYPE_F32) {
return op->src[1]->type == GGML_TYPE_F32;
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_MXFP4 ||
} else if (op->src[0]->type == GGML_TYPE_Q4_0 || op->src[0]->type == GGML_TYPE_Q4_1 ||
op->src[0]->type == GGML_TYPE_MXFP4 ||
op->src[0]->type == GGML_TYPE_Q4_K ||
op->src[0]->type == GGML_TYPE_Q6_K) {
return op->src[1]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1]);
} else if (op->src[0]->type == GGML_TYPE_Q8_0) {
@ -3592,6 +3756,21 @@ struct ggml_backend_opencl_buffer_context {
return extra;
}
ggml_tensor_extra_cl_q4_1 * ggml_opencl_alloc_temp_tensor_extra_q4_1() {
ggml_tensor_extra_cl_q4_1 * extra;
if (temp_tensor_extras_q4_1.empty()) {
extra = new ggml_tensor_extra_cl_q4_1();
} else {
extra = temp_tensor_extras_q4_1.back();
temp_tensor_extras_q4_1.pop_back();
}
temp_tensor_extras_q4_1_in_use.push_back(extra);
extra->reset();
return extra;
}
ggml_tensor_extra_cl_mxfp4 * ggml_opencl_alloc_temp_tensor_extra_mxfp4() {
ggml_tensor_extra_cl_mxfp4 * extra;
if (temp_tensor_extras_mxfp4.empty()) {
@ -3648,6 +3827,11 @@ struct ggml_backend_opencl_buffer_context {
}
temp_tensor_extras_q4_0_in_use.clear();
for (ggml_tensor_extra_cl_q4_1 * e : temp_tensor_extras_q4_1_in_use) {
temp_tensor_extras_q4_1.push_back(e);
}
temp_tensor_extras_q4_1_in_use.clear();
for (ggml_tensor_extra_cl_mxfp4 * e : temp_tensor_extras_mxfp4_in_use) {
temp_tensor_extras_mxfp4.push_back(e);
}
@ -3673,6 +3857,8 @@ struct ggml_backend_opencl_buffer_context {
std::vector<ggml_tensor_extra_cl *> temp_tensor_extras_in_use;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0;
std::vector<ggml_tensor_extra_cl_q4_0 *> temp_tensor_extras_q4_0_in_use;
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1;
std::vector<ggml_tensor_extra_cl_q4_1 *> temp_tensor_extras_q4_1_in_use;
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4;
std::vector<ggml_tensor_extra_cl_mxfp4 *> temp_tensor_extras_mxfp4_in_use;
std::vector<ggml_tensor_extra_cl_q8_0 *> temp_tensor_extras_q8_0;
@ -4042,6 +4228,75 @@ static void ggml_backend_opencl_buffer_set_tensor(ggml_backend_buffer_t buffer,
return;
}
if (tensor->type == GGML_TYPE_Q4_1) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
// Allocate the new extra and create aliases from the original.
ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context;
ggml_tensor_extra_cl_q4_1 * extra = ctx->ggml_opencl_alloc_temp_tensor_extra_q4_1();
size_t size_d = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_m = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*sizeof(ggml_fp16_t);
size_t size_q = ggml_nelements(tensor)/ggml_blck_size(tensor->type)*ggml_blck_size(tensor->type)/2;
GGML_ASSERT(size_d + size_m + size_q == ggml_nbytes(tensor) && "Incorrect tensor size");
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
CL_CHECK(clEnqueueWriteBuffer(
queue, data_device, CL_TRUE, 0,
ggml_nbytes(tensor), data, 0, NULL, NULL));
cl_buffer_region region;
// The original tensor memory is divided into scales and quants, i.e.,
// we first store scales, mins, then quants.
// Create subbuffer for scales.
region.origin = align_to(extra_orig->offset + tensor->view_offs + offset, backend_ctx->alignment);
region.size = size_d;
extra->d = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
auto previous_origin = region.origin;
// Create subbuffer for mins.
region.origin = align_to(previous_origin + size_d, backend_ctx->alignment);
region.size = size_m;
extra->m = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
previous_origin = region.origin;
// Create subbuffer for quants.
region.origin = align_to(previous_origin + size_m, backend_ctx->alignment);
region.size = size_q;
extra->q = clCreateSubBuffer(
extra_orig->data_device, CL_MEM_READ_WRITE,
CL_BUFFER_CREATE_TYPE_REGION, &region, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_convert_block_q4_1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra->m));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {64, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL, global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clReleaseMemObject(data_device));
tensor->extra = extra;
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
ggml_tensor_extra_cl * extra_orig = (ggml_tensor_extra_cl *)tensor->extra;
GGML_ASSERT(extra_orig && "Tesnors in OpenCL backend should have been allocated and initialized");
@ -4544,7 +4799,35 @@ static void ggml_backend_opencl_buffer_get_tensor(ggml_backend_buffer_t buffer,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
} else if (tensor->type == GGML_TYPE_MXFP4) {
}
if (tensor->type == GGML_TYPE_Q4_1) {
ggml_tensor_extra_cl_q4_1 * extra = (ggml_tensor_extra_cl_q4_1 *)tensor->extra;
cl_int err;
cl_mem data_device = clCreateBuffer(context, CL_MEM_READ_WRITE,
ggml_nbytes(tensor), NULL, &err);
CL_CHECK(err);
cl_kernel kernel = backend_ctx->kernel_restore_block_q4_1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra->m));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &data_device));
size_t global_work_size[] = {(size_t)ggml_nelements(tensor)/ggml_blck_size(tensor->type), 1, 1};
size_t local_work_size[] = {1, 1, 1};
cl_event evt;
CL_CHECK(clEnqueueNDRangeKernel(queue, kernel, 3, NULL,
global_work_size, local_work_size, 0, NULL, &evt));
CL_CHECK(clWaitForEvents(1, &evt));
CL_CHECK(clEnqueueReadBuffer(
queue, data_device, CL_TRUE, offset,
size, data, 0, NULL, NULL));
CL_CHECK(clReleaseMemObject(data_device));
return;
}
if (tensor->type == GGML_TYPE_MXFP4) {
ggml_tensor_extra_cl_mxfp4 * extra = (ggml_tensor_extra_cl_mxfp4 *)tensor->extra;
cl_int err;
@ -8372,6 +8655,7 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
#ifdef GGML_OPENCL_SOA_Q
ggml_tensor_extra_cl_q4_0 * extra0_q4_0 = (ggml_tensor_extra_cl_q4_0 *)src0->extra;
ggml_tensor_extra_cl_q4_1 * extra0_q4_1 = (ggml_tensor_extra_cl_q4_1 *)src0->extra;
ggml_tensor_extra_cl_mxfp4 * extra0_mxfp4 = (ggml_tensor_extra_cl_mxfp4 *)src0->extra;
ggml_tensor_extra_cl_q8_0 * extra0_q8_0 = (ggml_tensor_extra_cl_q8_0 *)src0->extra;
ggml_tensor_extra_cl_q6_K * extra0_q6_K = (ggml_tensor_extra_cl_q6_K *)src0->extra;
@ -8885,6 +9169,91 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q4_0: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q4_0_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_0->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_0->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q4_1: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q4_1_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q8_0: {
if (ne11 < 32) {
break;
@ -8927,6 +9296,50 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
case GGML_TYPE_Q6_K: {
if (ne11 < 32) {
break;
}
if (!ggml_is_contiguous(src0) || !ggml_is_contiguous(src1)) {
break;
}
kernel = backend_ctx->kernel_mul_mm_q6_k_f32_l4_lm;
nth0 = 128; // calculated as (BM*BN)/(TM*TN)
int batch_stride_a = ne00*ne01;
int batch_stride_b = ne10*ne11;
int batch_stride_d = ne0*ne1;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q6_K->ql));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q6_K->qh));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q6_K->s));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra0_q6_K->d));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne11));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne10)); // stride_a
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &ne10)); // stride_b
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne01)); // stride_d
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &batch_stride_a));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &batch_stride_b));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &batch_stride_d));
CL_CHECK(clSetKernelArg(kernel, 19, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 20, sizeof(int), &r3));
// 64 is block tile size BM and BN - change here when BM and BN in the kernel are changed.
size_t global_work_size[] = {(size_t)(CEIL_DIV(ne01, 64)*nth0), (size_t)(CEIL_DIV(ne11, 64)), (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
return;
}
default:
break;
}
@ -9181,7 +9594,71 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q4_1: {
#ifdef GGML_OPENCL_SOA_Q
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 1;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 1;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
kernel = backend_ctx->kernel_mul_mv_q4_1_f32_flat;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0_q4_1->q));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_mem), &extra0_q4_1->d));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra0_q4_1->m));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &r3));
#else
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 1;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 1;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
kernel = backend_ctx->kernel_mul_mv_q4_1_f32;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(int), &r3));
#endif // GGML_OPENCL_SOA_Q
break;
}
case GGML_TYPE_Q8_0: {
#ifdef GGML_OPENCL_SOA_Q
kernel = backend_ctx->kernel_mul_mv_q8_0_f32_flat;
@ -9262,7 +9739,42 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
}
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q4_K: {
kernel = backend_ctx->kernel_mul_mv_q4_K_f32;
if (backend_ctx->gpu_family == INTEL) {
nth0 = 16;
nth1 = 1;
ndst = 4;
} else if (backend_ctx->gpu_family == ADRENO) {
nth0 = 64;
nth1 = 1;
ndst = 4;
} else {
GGML_ASSERT(false && "TODO: Unknown GPU");
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(int), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(int), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(int), &ne12));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb12));
CL_CHECK(clSetKernelArg(kernel, 14, sizeof(cl_ulong), &nb13));
CL_CHECK(clSetKernelArg(kernel, 15, sizeof(int), &ne0));
CL_CHECK(clSetKernelArg(kernel, 16, sizeof(int), &ne1));
CL_CHECK(clSetKernelArg(kernel, 17, sizeof(int), &r2));
CL_CHECK(clSetKernelArg(kernel, 18, sizeof(int), &r3));
break;
}
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
#ifdef GGML_OPENCL_SOA_Q
@ -9424,7 +9936,10 @@ static void ggml_cl_mul_mat(ggml_backend_t backend, const ggml_tensor * src0, co
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q4_K) {
GGML_ASSERT(false && "not implemented");
size_t global_work_size[] = {(size_t)(ne01+ndst*nth1-1)/(ndst*nth1)*nth0, (size_t)ne11*nth1, (size_t)ne12*ne13};
size_t local_work_size[] = {(size_t)nth0, (size_t)nth1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
} else if (src0t == GGML_TYPE_Q3_K) {
GGML_ASSERT(false && "not implemented");
} else if (src0t == GGML_TYPE_Q5_K) {

View File

@ -46,6 +46,15 @@ struct block_q4_0
uint8_t qs[QK4_0 / 2];
};
//------------------------------------------------------------------------------
// block_q4_1
//------------------------------------------------------------------------------
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
//------------------------------------------------------------------------------
// block_q6_K
//------------------------------------------------------------------------------
@ -148,6 +157,48 @@ kernel void kernel_restore_block_q4_0_noshuffle(
}
}
//------------------------------------------------------------------------------
// kernel_convert_block_q4_1
// Convert the block_q4_1 format to 2 separate arrays (AOS -> SOA).
// This kernel does not deshuffle the bits.
//------------------------------------------------------------------------------
kernel void kernel_convert_block_q4_1(
global struct block_q4_1 * src0,
global uchar * dst_q,
global half * dst_d,
global half * dst_m
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) src0 + get_global_id(0);
global uchar * q = (global uchar *) dst_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) dst_d + get_global_id(0);
global half * m = (global half *) dst_m + get_global_id(0);
*d = b->d;
*m = b->m;
for (int i = 0; i < QK4_1/2; ++i) {
q[i] = b->qs[i];
}
}
kernel void kernel_restore_block_q4_1(
global uchar * src_q,
global half * src_d,
global half * src_m,
global struct block_q4_1 * dst
) {
global struct block_q4_1 * b = (global struct block_q4_1 *) dst + get_global_id(0);
global uchar * q = (global uchar *) src_q + QK4_1/2*get_global_id(0);
global half * d = (global half *) src_d + get_global_id(0);
global half * m = (global half *) src_m + get_global_id(0);
b->d = *d;
b->m = *m;
for (int i = 0; i < QK4_1/2; ++i) {
b->qs[i] = q[i];
}
}
//------------------------------------------------------------------------------
// block_mxfp4
//------------------------------------------------------------------------------

View File

@ -0,0 +1,163 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_0_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)) - 8.0f)*d;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)) - 8.0f)*d;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}

View File

@ -0,0 +1,165 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 8
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q4_1_f32_l4_lm(
global uchar4 * src0_q,
global half * src0_d,
global half * src0_m,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 4;
int iqs = idx % 4;
float d = (float)src0_d[ib];
float m = (float)src0_m[ib];
global uchar4 * qs = src0_q + ib*4 + iqs;
uchar4 q = *qs;
float4 v1 = (convert_float4((uchar4)((q.s0 )&0x0F, (q.s1 )&0x0F, (q.s2 )&0x0F, (q.s3 )&0x0F)))*d + m;
float4 v2 = (convert_float4((uchar4)((q.s0>>4)&0x0F, (q.s1>>4)&0x0F, (q.s2>>4)&0x0F, (q.s3>>4)&0x0F)))*d + m;
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = v1.s0;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = v1.s1;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = v1.s2;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = v1.s3;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = v2.s0;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = v2.s1;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = v2.s2;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = v2.s3;
} else {
buf_a[(loadr_a * 4 + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 1) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 2) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 3) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 16) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 17) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 18) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * 4 + 19) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}

View File

@ -0,0 +1,158 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#define LOAD_VEC_A 2
#define LOAD_VEC_B 4
#define BM 64
#define BN 64
#define BK 32
#define TM 4
#define TN 8
kernel void kernel_mul_mm_q6_k_f32_l4_lm(
global uchar * src0_ql,
global uchar * src0_qh,
global char * src0_s,
global half * src0_d,
global float4 * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne11,
int ne12,
int stride_a,
int stride_b,
int stride_d,
int batch_stride_a,
int batch_stride_b,
int batch_stride_d,
int r2,
int r3
) {
src1 = (global float4*)((global char*)src1 + offset1);
dst = (global float *)((global char*)dst + offsetd);
local float buf_a[BM * BK];
local float buf_b[BN * BK];
const int batch_idx = get_global_id(2);
const int i13 = batch_idx / ne12;
const int i12 = batch_idx % ne12;
const int i03 = i13 / r3;
const int i02 = i12 / r2;
const int batch_idx_a = i03 * ne02 + i02;
const int ir = get_group_id(0);
const int ic = get_group_id(1);
const int tid = get_local_id(0);
const int th_r = tid % (BM / TM);
const int th_c = tid / (BM / TM);
const int loadr_a = get_local_id(0) % (BK / LOAD_VEC_A);
const int loadc_a = get_local_id(0) / (BK / LOAD_VEC_A);
const int loadr_b = get_local_id(0) % (BK / LOAD_VEC_B);
const int loadc_b = get_local_id(0) / (BK / LOAD_VEC_B);
const int loadstride_a = get_local_size(0) * LOAD_VEC_A / BK;
const int loadstride_b = get_local_size(0) * LOAD_VEC_B / BK;
int pos_a = (batch_idx_a * batch_stride_a + ir * BM * stride_a) / LOAD_VEC_A;
int pos_b = (batch_idx * batch_stride_b + ic * BN * stride_b) / LOAD_VEC_B;
float sums[TM * TN];
float cache_a[TM];
float cache_b[TN];
for (int i = 0; i < TM * TN; i++) {
sums[i] = 0.0f;
}
for (int block = 0; block < ne00; block += BK) {
for (int l = 0; l < BM; l += loadstride_a) {
if (ir*BM + loadc_a + l < ne01) {
int idx = pos_a + (loadc_a + l) * stride_a / LOAD_VEC_A + loadr_a;
int ib = idx / 128; // 2 values per idx
int iqs = idx % 128; // 0..127
int n = iqs / 64; // 0,1
int b = (iqs % 64) / 32; // 0,1
int is_b = (iqs % 16) / 8; // 0,1
int qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6
int is = 8 * n + qhshift + is_b; // 0..15
int qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126
int qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62
float dscale = (float)src0_d[ib] * (float)src0_s[ib*16 + is];
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 0] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 0] >> qhshift) & 3) << 4)) - 32);
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = dscale * convert_float(convert_char(((src0_ql[128*ib + qsi + 1] >> (b * 4)) & 0xF) | (((src0_qh[64*ib + qhi + 1] >> qhshift) & 3) << 4)) - 32);
} else {
buf_a[(loadr_a * LOAD_VEC_A + 0) * BM + loadc_a + l] = 0.0f;
buf_a[(loadr_a * LOAD_VEC_A + 1) * BM + loadc_a + l] = 0.0f;
}
}
for (int l = 0; l < BN; l += loadstride_b) {
if (ic*BN + loadc_b + l < ne11) {
int idx = pos_b + (loadc_b + l) * stride_b / LOAD_VEC_B + loadr_b;
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = src1[idx].s0;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = src1[idx].s1;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = src1[idx].s2;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = src1[idx].s3;
} else {
buf_b[(loadr_b * LOAD_VEC_B + 0) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 1) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 2) * BN + loadc_b + l] = 0.0f;
buf_b[(loadr_b * LOAD_VEC_B + 3) * BN + loadc_b + l] = 0.0f;
}
}
barrier(CLK_LOCAL_MEM_FENCE);
pos_a += BK / LOAD_VEC_A;
pos_b += BK / LOAD_VEC_B;
for (int i = 0; i < BK; i++) {
for (int j = 0; j < TM; j++) {
cache_a[j] = buf_a[(i) * BM + th_r * TM + j];
}
for (int j = 0; j < TN; j++) {
cache_b[j] = buf_b[(i) * BN + th_c * TN + j];
}
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
const int sums_idx = cc*TM + cr;
sums[sums_idx] = mad(cache_a[cr], cache_b[cc], sums[sums_idx]);
}
}
}
barrier(CLK_LOCAL_MEM_FENCE);
}
const int dr = ir * BM + th_r * TM;
const int dc = ic * BN + th_c * TN;
const int offsets = batch_idx * batch_stride_d;
for (int cc = 0; cc < TN; cc++) {
for (int cr = 0; cr < TM; cr++) {
if (dr + cr < ne01 && dc + cc < ne11) {
dst[offsets + (dc + cc) * stride_d + dr + cr] = sums[cc * TM + cr];
}
}
}
}

View File

@ -0,0 +1,219 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y(
global const struct block_q4_1 * qb_curr,
float sumy,
float16 yl,
int il
) {
float d = qb_curr->d;
float m = qb_curr->m;
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
global const ushort * qs = ((global const ushort *) qb_curr + 2 + il/2);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32(
global void * src0,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
global struct block_q4_1 * x = (global struct block_q4_1 *) src0 + offset0;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y(x+ib+0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y(x+ib+1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y(x+ib+2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y(x+ib+3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32(
global void * src0,
ulong offset0,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = (global void*)((global char*)src0 + offset0);
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32(src0, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,229 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
#ifdef cl_intel_subgroups
#pragma OPENCL EXTENSION cl_intel_subgroups : enable
#else
#pragma OPENCL EXTENSION cl_khr_subgroups : enable
#endif
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
#define QK4_1 32
struct block_q4_1 {
half d; // delta
half m; // min
uchar qs[QK4_1 / 2]; // nibbles / quants
};
inline float block_q4_1_dot_y_flat(
global const uchar * x,
global const half * dh,
global const half * mh,
float sumy,
float16 yl,
int il
) {
float d = *dh;
float m = *mh;
global const ushort * qs = ((global const ushort *) x + il/2);
float4 acc = (float4)(0.0f, 0.0f, 0.0f, 0.0f);
acc.s0 += yl.s0 * (qs[0] & 0x000F);
acc.s0 += yl.s1 * (qs[0] & 0x0F00);
acc.s0 += yl.s8 * (qs[0] & 0x00F0);
acc.s3 += yl.s9 * (qs[0] & 0xF000);
acc.s0 += yl.s2 * (qs[1] & 0x000F);
acc.s1 += yl.s3 * (qs[1] & 0x0F00);
acc.s2 += yl.sa * (qs[1] & 0x00F0);
acc.s3 += yl.sb * (qs[1] & 0xF000);
acc.s0 += yl.s4 * (qs[2] & 0x000F);
acc.s1 += yl.s5 * (qs[2] & 0x0F00);
acc.s2 += yl.sc * (qs[2] & 0x00F0);
acc.s3 += yl.sd * (qs[2] & 0xF000);
acc.s0 += yl.s6 * (qs[3] & 0x000F);
acc.s1 += yl.s7 * (qs[3] & 0x0F00);
acc.s2 += yl.se * (qs[3] & 0x00F0);
acc.s3 += yl.sf * (qs[3] & 0xF000);
return d * (acc.s0 + acc.s1 + acc.s2 + acc.s3) + sumy * m;
}
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // each subgroup works on 4 rows
#define N_SIMDGROUP 1 // number of subgroups in a thread group
#define N_SIMDWIDTH 16 // assuming subgroup size is 16
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
inline void mul_vec_q_n_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
global float * dst,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
const ulong nb = ne00/QK4_1;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
ulong offset0 = first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02);
// The number of scales/mins is the same as the number of blocks.
ulong offset0_dm = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02));
// Each block contains QK4_1/2 uchars, hence offset for qs is as follows.
ulong offset0_q = (first_row * nb + (i12/r2)*(nb*ne01) + (i13/r3)*(nb*ne01*ne02)) * QK4_1/2;
global uchar * x = (global uchar *) src0_q + offset0_q;
global half * d = (global half *) src0_d + offset0_dm;
global half * m = (global half *) src0_m + offset0_dm;
global float * y = (global float *) src1 + r1*ne10 + im*ne00*ne1;
float16 yl;
float4 sumf = (float4)(0.f, 0.f, 0.f, 0.f);
int ix = get_sub_group_local_id()/2;
int il = 8*(get_sub_group_local_id()%2);
global float * yb = y + ix * QK4_1 + il;
for (int ib = ix; ib < nb; ib += N_SIMDWIDTH/2) {
float sumy = 0;
sumy += yb[0];
sumy += yb[1];
sumy += yb[2];
sumy += yb[3];
sumy += yb[4];
sumy += yb[5];
sumy += yb[6];
sumy += yb[7];
sumy += yb[16];
sumy += yb[17];
sumy += yb[18];
sumy += yb[19];
sumy += yb[20];
sumy += yb[21];
sumy += yb[22];
sumy += yb[23];
yl.s0 = yb[0];
yl.s1 = yb[1]/256.f;
yl.s2 = yb[2];
yl.s3 = yb[3]/256.f;
yl.s4 = yb[4];
yl.s5 = yb[5]/256.f;
yl.s6 = yb[6];
yl.s7 = yb[7]/256.f;
yl.s8 = yb[16]/16.f;
yl.s9 = yb[17]/4096.f;
yl.sa = yb[18]/16.f;
yl.sb = yb[19]/4096.f;
yl.sc = yb[20]/16.f;
yl.sd = yb[21]/4096.f;
yl.se = yb[22]/16.f;
yl.sf = yb[23]/4096.f;
sumf.s0 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 0*nb*QK4_1/2, d + ib + 0*nb, m + ib + 0*nb, sumy, yl, il);
sumf.s1 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 1*nb*QK4_1/2, d + ib + 1*nb, m + ib + 1*nb, sumy, yl, il);
sumf.s2 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 2*nb*QK4_1/2, d + ib + 2*nb, m + ib + 2*nb, sumy, yl, il);
sumf.s3 += block_q4_1_dot_y_flat(x + ib*QK4_1/2 + 3*nb*QK4_1/2, d + ib + 3*nb, m + ib + 3*nb, sumy, yl, il);
yb += QK4_1 * (N_SIMDWIDTH/2);
}
float4 tot = (float4)(
sub_group_reduce_add(sumf.s0), sub_group_reduce_add(sumf.s1),
sub_group_reduce_add(sumf.s2), sub_group_reduce_add(sumf.s3)
);
if (get_sub_group_local_id() == 0) {
if (first_row + 0 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 0] = tot.s0;
}
if (first_row + 1 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 1] = tot.s1;
}
if (first_row + 2 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 2] = tot.s2;
}
if (first_row + 3 < ne01) {
dst[r1*ne0 + im*ne0*ne1 + first_row + 3] = tot.s3;
}
}
}
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_1_f32_flat(
global void * src0_q,
global void * src0_d,
global void * src0_m,
global float * src1,
ulong offset1,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne10,
int ne12,
int ne0,
int ne1,
int r2,
int r3
) {
src1 = (global float*)((global char*)src1 + offset1);
dst = (global float*)((global char*)dst + offsetd);
mul_vec_q_n_f32_flat(src0_q, src0_d, src0_m, src1, dst, ne00, ne01, ne02, ne10, ne12, ne0, ne1, r2, r3);
}

View File

@ -0,0 +1,180 @@
#ifdef cl_intel_required_subgroup_size
#pragma OPENCL EXTENSION cl_intel_required_subgroup_size : enable
#define INTEL_GPU 1
#define REQD_SUBGROUP_SIZE_16 __attribute__((intel_reqd_sub_group_size(16)))
#define REQD_SUBGROUP_SIZE_32 __attribute__((intel_reqd_sub_group_size(32)))
#elif defined(cl_qcom_reqd_sub_group_size)
#pragma OPENCL EXTENSION cl_qcom_reqd_sub_group_size : enable
#define ADRENO_GPU 1
#define REQD_SUBGROUP_SIZE_64 __attribute__((qcom_reqd_sub_group_size("half")))
#define REQD_SUBGROUP_SIZE_128 __attribute__((qcom_reqd_sub_group_size("full")))
#endif
//------------------------------------------------------------------------------
// block_q4_K
//------------------------------------------------------------------------------
#define QK_K 256
#define K_SCALE_SIZE 12
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
typedef struct {
half d; // super-block scale for quantized scales
half dmin; // super-block scale for quantized mins
uchar scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
uchar qs[QK_K/2]; // 4-bit quants
} block_q4_K;
#undef N_DST
#undef N_SIMDGROUP
#undef N_SIMDWIDTH
#ifdef INTEL_GPU
#define N_DST 4 // number of rows each SIMD group works on
#define N_SIMDGROUP 1 // number of SIMD groups in a thread group
#define N_SIMDWIDTH 16 // SIMD group size
#elif defined (ADRENO_GPU)
#define N_DST 4
#define N_SIMDGROUP 1
#define N_SIMDWIDTH 64
#endif
#undef BLOCK_STRIDE
// number of (super) blocks each subgroup processes
// each thread in a subgroup processes a block (32 weights)
#define BLOCK_STRIDE (N_SIMDWIDTH/8)
#ifdef INTEL_GPU
REQD_SUBGROUP_SIZE_16
#elif defined (ADRENO_GPU)
REQD_SUBGROUP_SIZE_64
#endif
kernel void kernel_mul_mv_q4_K_f32(
global char * src0,
int offset0,
global char * src1,
int offset1,
global char * dst,
int offsetd,
int ne00,
int ne01,
ulong nb01,
ulong nb02,
ulong nb03,
int ne12,
ulong nb11,
ulong nb12,
ulong nb13,
int ne0,
int ne1,
int r2,
int r3
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
ushort kmask1 = 0x3f3f;
ushort kmask2 = 0x0f0f;
ushort kmask3 = 0xc0c0;
int ix = get_sub_group_local_id()/8; // super block index
int it = get_sub_group_local_id()%8; // block index (inside super block)
int iq = it/4; // 0 or 1 - first or second half of the super block
int ir = it%4; // 0...3 - block index in the half super block
int nb = ne00/QK_K;
int r0 = get_group_id(0);
int r1 = get_group_id(1);
int im = get_group_id(2);
int first_row = (r0 * N_SIMDGROUP + get_sub_group_id()) * N_DST;
int i12 = im%ne12;
int i13 = im/ne12;
int offset_src0 = first_row*nb01 + (i12/r2)*nb02 + (i13/r3)*nb03;
int offset_src1 = r1*nb11 + (i12 )*nb12 + (i13 )*nb13;
global block_q4_K * x = (global block_q4_K *) (src0 + offset_src0);
global float * y = (global float *) (src1 + offset_src1);
float yl[16];
float yh[16];
float sumf[N_DST] = {0.f};
float all_sum;
global float * y4 = y + ix * QK_K + 64 * iq + 8 * ir;
ushort sc16[4];
uchar * sc8 = (uchar *)sc16;
for (int ib = ix; ib < nb; ib += BLOCK_STRIDE) {
float4 sumy = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; ++i) {
yl[i+0] = y4[i+0];
sumy.s0 += yl[i+0];
yl[i+8] = y4[i+32];
sumy.s1 += yl[i+8];
yh[i+0] = y4[i+128];
sumy.s2 += yh[i+0];
yh[i+8] = y4[i+160];
sumy.s3 += yh[i+8];
}
global ushort * sc = (global ushort *)x[ib].scales + iq;
global ushort * q1 = (global ushort *)x[ib].qs + 16 * iq + 4 * ir;
global half * dh = &x[ib].d;
for (int row = 0; row < N_DST; row++) {
sc16[0] = sc[0] & kmask1;
sc16[1] = sc[2] & kmask1;
sc16[2] = ((sc[4] >> 0) & kmask2) | ((sc[0] & kmask3) >> 2);
sc16[3] = ((sc[4] >> 4) & kmask2) | ((sc[2] & kmask3) >> 2);
global ushort * q2 = q1 + 32;
float4 acc1 = {0.f, 0.f, 0.f, 0.f};
float4 acc2 = {0.f, 0.f, 0.f, 0.f};
for (int i = 0; i < 8; i += 2) {
acc1.s0 += yl[i+0] * (q1[i/2] & 0x000F);
acc1.s1 += yl[i+1] * (q1[i/2] & 0x0F00);
acc1.s2 += yl[i+8] * (q1[i/2] & 0x00F0);
acc1.s3 += yl[i+9] * (q1[i/2] & 0xF000);
acc2.s0 += yh[i+0] * (q2[i/2] & 0x000F);
acc2.s1 += yh[i+1] * (q2[i/2] & 0x0F00);
acc2.s2 += yh[i+8] * (q2[i/2] & 0x00F0);
acc2.s3 += yh[i+9] * (q2[i/2] & 0xF000);
}
float dall = dh[0];
float dmin = dh[1];
sumf[row] += dall * ((acc1.s0 + 1.f/256.f * acc1.s1) * sc8[0] +
(acc1.s2 + 1.f/256.f * acc1.s3) * sc8[1] * 1.f/16.f +
(acc2.s0 + 1.f/256.f * acc2.s1) * sc8[4] +
(acc2.s2 + 1.f/256.f * acc2.s3) * sc8[5] * 1.f/16.f) -
dmin * (sumy.s0 * sc8[2] + sumy.s1 * sc8[3] + sumy.s2 * sc8[6] + sumy.s3 * sc8[7]);
q1 += nb01/2;
sc += nb01/2;
dh += nb01/2;
}
y4 += BLOCK_STRIDE * QK_K;
}
global float * dst_f32 = (global float *) dst + im*ne0*ne1 + r1*ne0;
for (int row = 0; row < N_DST; ++row) {
all_sum = sub_group_reduce_add(sumf[row]);
if (first_row + row < ne01) {
if (get_sub_group_local_id() == 0) {
dst_f32[first_row + row] = all_sum;
}
}
}
}

View File

@ -836,16 +836,9 @@ static inline void ggml_sycl_op_floor(ggml_backend_sycl_context & ctx, ggml_tens
}
static inline void ggml_sycl_op_ceil(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {
ggml_sycl_detail::dispatch_ggml_sycl_op_unary(ctx, dst,
[](const auto* src, auto* dst_ptr, int k_elements, queue_ptr stream) {
const int num_blocks = ceil_div(k_elements, 256);
stream->parallel_for(
sycl::nd_range<1>(sycl::range<1>(num_blocks) * sycl::range<1>(256),
sycl::range<1>(256)),
[=](sycl::nd_item<1> item_ct1) {
unary_op_ceil_kernel(src, dst_ptr, k_elements, item_ct1);
});
});
ggml_sycl_detail::ggml_sycl_op_unary(ctx, dst, [](auto x) {
return op_ceil(x);
});
}
static inline void ggml_sycl_op_round(ggml_backend_sycl_context & ctx, ggml_tensor * dst) {

View File

@ -4591,9 +4591,9 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SOFTPLUS:
case GGML_UNARY_OP_ELU:
case GGML_UNARY_OP_CEIL:
return true;
case GGML_UNARY_OP_FLOOR:
case GGML_UNARY_OP_CEIL:
case GGML_UNARY_OP_ROUND:
case GGML_UNARY_OP_TRUNC:
#if defined (GGML_SYCL_F16)

View File

@ -4,6 +4,7 @@
#include "ggml.h"
#include "pre_wgsl.hpp"
#include <memory>
#include <string>
#include <vector>
@ -18,9 +19,9 @@
#define GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE 512u
struct ggml_webgpu_processed_shader {
std::string wgsl;
std::string variant;
void * decisions;
std::string wgsl;
std::string variant;
std::shared_ptr<void> decisions;
};
// Same hash combine function as in boost
@ -192,13 +193,13 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_flash_attn_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_flash_attn_shader_decisions * decisions = new ggml_webgpu_flash_attn_shader_decisions();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_flash_attn_shader_decisions>();
decisions->q_tile = q_tile;
decisions->kv_tile = kv_tile;
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
@ -270,11 +271,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_pad_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
@ -305,11 +306,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_shader(
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
@ -324,11 +325,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_argsort_merge_shader(
uint32_t wg_size = std::min(GGML_WEBGPU_ARGSORT_MERGE_MAX_WG_SIZE, context.max_wg_size);
defines.push_back(std::string("WG_SIZE=") + std::to_string(wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_argsort_shader_decisions * decisions = new ggml_webgpu_argsort_shader_decisions();
decisions->wg_size = wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_argsort_shader_decisions>();
decisions->wg_size = wg_size;
result.decisions = decisions;
return result;
}
@ -391,11 +392,11 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_set_rows_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
@ -457,12 +458,81 @@ inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_unary_shader(
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
ggml_webgpu_generic_shader_decisions * decisions = new ggml_webgpu_generic_shader_decisions();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
/** Binary **/
struct ggml_webgpu_binary_pipeline_key {
int type;
int op;
bool inplace;
bool overlap;
bool operator==(const ggml_webgpu_binary_pipeline_key & other) const {
return type == other.type && op == other.op && inplace == other.inplace && overlap == other.overlap;
}
};
struct ggml_webgpu_binary_pipeline_key_hash {
size_t operator()(const ggml_webgpu_binary_pipeline_key & key) const {
size_t seed = 0;
ggml_webgpu_hash_combine(seed, key.type);
ggml_webgpu_hash_combine(seed, key.op);
ggml_webgpu_hash_combine(seed, key.inplace);
ggml_webgpu_hash_combine(seed, key.overlap);
return seed;
}
};
struct ggml_webgpu_binary_shader_lib_context {
ggml_webgpu_binary_pipeline_key key;
uint32_t max_wg_size;
};
inline ggml_webgpu_processed_shader ggml_webgpu_preprocess_binary_shader(
pre_wgsl::Preprocessor & preprocessor,
const char * shader_src,
const ggml_webgpu_binary_shader_lib_context & context) {
std::vector<std::string> defines;
std::string op_name = ggml_op_name((ggml_op) context.key.op);
std::string variant = op_name;
defines.push_back(std::string("OP_") + op_name);
switch (context.key.type) {
case GGML_TYPE_F32:
defines.push_back("TYPE_F32");
variant += "_f32";
break;
case GGML_TYPE_F16:
defines.push_back("TYPE_F16");
variant += "_f16";
break;
default:
GGML_ABORT("Unsupported type for binary shader");
}
if (context.key.inplace) {
defines.push_back("INPLACE");
variant += "_inplace";
} else if (context.key.overlap) {
defines.push_back("OVERLAP");
variant += "_overlap";
}
defines.push_back(std::string("WG_SIZE=") + std::to_string(context.max_wg_size));
ggml_webgpu_processed_shader result;
result.wgsl = preprocessor.preprocess(shader_src, defines);
result.variant = variant;
auto decisions = std::make_shared<ggml_webgpu_generic_shader_decisions>();
decisions->wg_size = context.max_wg_size;
result.decisions = decisions;
return result;
}
#endif // GGML_WEBGPU_SHADER_LIB_HPP

View File

@ -186,11 +186,17 @@ struct webgpu_buf_pool {
void cleanup() {
std::lock_guard<std::mutex> lock(mutex);
for (auto & bufs : free) {
bufs.host_buf.Destroy();
bufs.dev_buf.Destroy();
if (bufs.host_buf) {
bufs.host_buf.Destroy();
}
if (bufs.dev_buf) {
bufs.dev_buf.Destroy();
}
}
free.clear();
}
~webgpu_buf_pool() { this->cleanup(); }
};
#ifdef GGML_WEBGPU_GPU_PROFILE
@ -252,13 +258,15 @@ struct webgpu_gpu_profile_buf_pool {
}
free.clear();
}
~webgpu_gpu_profile_buf_pool() { this->cleanup(); }
};
#endif
struct webgpu_pipeline {
wgpu::ComputePipeline pipeline;
std::string name;
void * context = nullptr;
std::shared_ptr<void> context = nullptr;
};
struct webgpu_command {
@ -319,6 +327,23 @@ struct webgpu_global_context_struct {
wgpu::Buffer debug_host_buf;
wgpu::Buffer debug_dev_buf;
#endif
~webgpu_global_context_struct() {
if (this->get_tensor_staging_buf) {
this->get_tensor_staging_buf.Destroy();
this->get_tensor_staging_buf = nullptr;
}
#ifdef GGML_WEBGPU_DEBUG
if (this->debug_host_buf) {
this->debug_host_buf.Destroy();
this->debug_host_buf = nullptr;
}
if (this->debug_dev_buf) {
this->debug_dev_buf.Destroy();
this->debug_dev_buf = nullptr;
}
#endif
}
};
typedef std::shared_ptr<webgpu_global_context_struct> webgpu_global_context;
@ -348,13 +373,12 @@ struct webgpu_context_struct {
std::unordered_map<ggml_webgpu_set_rows_pipeline_key, webgpu_pipeline, ggml_webgpu_set_rows_pipeline_key_hash>
set_rows_pipelines;
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> get_rows_pipelines; // src_type, vectorized
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::map<int, std::map<int, webgpu_pipeline>> add_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> sub_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> mul_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> div_pipelines; // type, inplace
std::map<int, std::map<int, webgpu_pipeline>> cpy_pipelines; // src_type, dst_type
std::unordered_map<ggml_webgpu_binary_pipeline_key, webgpu_pipeline, ggml_webgpu_binary_pipeline_key_hash>
binary_pipelines;
std::map<int, webgpu_pipeline> rms_norm_pipelines; // inplace
std::map<int, std::map<int, std::map<int, webgpu_pipeline>>> rope_pipelines; // type, ff, inplace
@ -745,7 +769,6 @@ static const char * ggml_backend_webgpu_name(ggml_backend_t backend) {
return ctx->name.c_str();
}
// TODO: implement proper cleanup
static void ggml_backend_webgpu_free(ggml_backend_t backend) {
ggml_backend_webgpu_context * ctx = (ggml_backend_webgpu_context *) backend->context;
WEBGPU_LOG_DEBUG("ggml_backend_webgpu_free(" << ctx->name << ")");
@ -789,9 +812,8 @@ static void ggml_backend_webgpu_free(ggml_backend_t backend) {
std::cout << "ggml_webgpu: gpu/cpu ratio: " << (total_cpu > 0.0 ? total_gpu / total_cpu : 0.0) << "\n";
#endif
#if !defined(GGML_WEBGPU_CPU_PROFILE) && !defined(GGML_WEBGPU_GPU_PROFILE)
GGML_UNUSED(ctx);
#endif
delete ctx;
delete backend;
}
static size_t ggml_webgpu_tensor_offset(const ggml_tensor * tensor) {
@ -823,6 +845,28 @@ static bool ggml_webgpu_tensor_equal(ggml_tensor * a, ggml_tensor * b) {
(ggml_webgpu_tensor_offset(a) == ggml_webgpu_tensor_offset(b));
}
// Used to determine if two tensors share the same buffer and their byte ranges overlap,
static bool ggml_webgpu_tensor_overlap(ggml_tensor * a, ggml_tensor * b) {
return (ggml_webgpu_tensor_buf(a).Get() == ggml_webgpu_tensor_buf(b).Get()) &&
ggml_webgpu_tensor_offset(a) < (ggml_webgpu_tensor_offset(b) + ggml_nbytes(b)) &&
ggml_webgpu_tensor_offset(b) < (ggml_webgpu_tensor_offset(a) + ggml_nbytes(a));
}
struct binary_overlap_flags {
bool inplace; // src0 == dst
bool overlap; // src1 == dst
};
static binary_overlap_flags ggml_webgpu_detect_binary_overlap(ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = {};
flags.inplace = ggml_webgpu_tensor_equal(src0, dst);
flags.overlap = ggml_webgpu_tensor_overlap(src1, dst);
return flags;
}
static webgpu_command ggml_webgpu_cpy(webgpu_context & ctx, ggml_tensor * src, ggml_tensor * dst) {
uint32_t ne = (uint32_t) ggml_nelements(dst);
@ -875,8 +919,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
ctx->pad_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
const uint32_t ne = (uint32_t) ggml_nelements(dst);
@ -920,7 +963,7 @@ static webgpu_command ggml_webgpu_pad(webgpu_context & ctx, ggml_tensor * src, g
.size = ggml_webgpu_tensor_binding_size(ctx, dst) }
};
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -954,8 +997,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
ctx->set_rows_pipelines.emplace(key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
std::optional<webgpu_pool_bufs> error_bufs = std::nullopt;
if (key.i64_idx) {
@ -1007,7 +1049,7 @@ static std::optional<webgpu_command> ggml_webgpu_set_rows(webgpu_context & ctx,
} else {
threads = src->ne[0] * src->ne[1] * src->ne[2] * src->ne[3];
}
uint32_t wg_x = CEIL_DIV(threads, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(threads, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x, 1,
error_bufs);
}
@ -1276,10 +1318,9 @@ static webgpu_command ggml_webgpu_flash_attn(webgpu_context & ctx,
ctx->flash_attn_pipelines.emplace(key, pipeline);
}
ggml_webgpu_flash_attn_shader_decisions decisions =
*static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_flash_attn_shader_decisions *>(pipeline.context.get());
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions.q_tile);
uint32_t wg_per_head = CEIL_DIV(Q->ne[1], decisions->q_tile);
uint32_t wg_x = wg_per_head * Q->ne[2] * Q->ne[3]; // wg per head * number of heads * number of batches
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -1310,8 +1351,7 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
ctx->unary_pipelines.emplace(pipeline_key, pipeline);
}
ggml_webgpu_generic_shader_decisions decisions =
*static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context);
auto * decisions = static_cast<ggml_webgpu_generic_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
@ -1371,18 +1411,45 @@ static webgpu_command ggml_webgpu_unary_op(webgpu_context & ctx, ggml_tensor * s
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ne, decisions.wg_size);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst,
webgpu_pipeline & pipeline,
bool inplace) {
static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
ggml_tensor * src0,
ggml_tensor * src1,
ggml_tensor * dst) {
binary_overlap_flags flags = ggml_webgpu_detect_binary_overlap(src0, src1, dst);
ggml_webgpu_binary_pipeline_key pipeline_key = {
.type = dst->type,
.op = dst->op,
.inplace = flags.inplace,
.overlap = flags.overlap,
};
ggml_webgpu_binary_shader_lib_context shader_lib_ctx = {
.key = pipeline_key, .max_wg_size = ctx->global_ctx->capabilities.limits.maxComputeInvocationsPerWorkgroup
};
webgpu_pipeline pipeline;
auto it = ctx->binary_pipelines.find(pipeline_key);
if (it != ctx->binary_pipelines.end()) {
pipeline = it->second;
} else {
ggml_webgpu_processed_shader processed =
ggml_webgpu_preprocess_binary_shader(ctx->p, wgsl_binary, shader_lib_ctx);
pipeline =
ggml_webgpu_create_pipeline(ctx->global_ctx->device, processed.wgsl.c_str(), processed.variant.c_str());
pipeline.context = processed.decisions;
ctx->binary_pipelines.emplace(pipeline_key, pipeline);
}
auto * decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(pipeline.context.get());
uint32_t ne = (uint32_t) ggml_nelements(dst);
std::vector<uint32_t> params = {
(uint32_t) ggml_nelements(dst),
ne,
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src0) / ggml_type_size(src0->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, src1) / ggml_type_size(src1->type)),
(uint32_t) (ggml_webgpu_tensor_misalignment(ctx, dst) / ggml_type_size(dst->type)),
@ -1399,24 +1466,30 @@ static webgpu_command ggml_webgpu_binary_op(webgpu_context & ctx,
(uint32_t) src1->ne[3],
};
std::vector<wgpu::BindGroupEntry> entries = {
{ .binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0) },
{ .binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1) }
};
if (!inplace) {
std::vector<wgpu::BindGroupEntry> entries;
entries.push_back({
.binding = 0,
.buffer = ggml_webgpu_tensor_buf(src0),
.offset = ggml_webgpu_tensor_align_offset(ctx, src0),
.size = ggml_webgpu_tensor_binding_size(ctx, src0),
});
entries.push_back({
.binding = 1,
.buffer = ggml_webgpu_tensor_buf(src1),
.offset = ggml_webgpu_tensor_align_offset(ctx, src1),
.size = ggml_webgpu_tensor_binding_size(ctx, src1),
});
if (!flags.inplace && !flags.overlap) {
entries.push_back({ .binding = 2,
.buffer = ggml_webgpu_tensor_buf(dst),
.offset = ggml_webgpu_tensor_align_offset(ctx, dst),
.size = ggml_webgpu_tensor_binding_size(ctx, dst) });
}
uint32_t wg_x = CEIL_DIV(ggml_nelements(dst), WEBGPU_MAX_WG_SIZE);
uint32_t wg_x = CEIL_DIV(ne, decisions->wg_size);
return ggml_backend_webgpu_build(ctx->global_ctx, ctx->param_buf_pool, pipeline, params, entries, wg_x);
}
@ -1766,8 +1839,7 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
argsort_pipeline.context = processed.decisions;
ctx->argsort_pipelines.emplace(order, argsort_pipeline);
}
ggml_webgpu_argsort_shader_decisions argsort_decisions =
*static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context);
auto * argsort_decisions = static_cast<ggml_webgpu_argsort_shader_decisions *>(argsort_pipeline.context.get());
webgpu_pipeline argsort_merge_pipeline;
it = ctx->argsort_merge_pipelines.find(order);
@ -1784,13 +1856,13 @@ static webgpu_command ggml_webgpu_argsort(webgpu_context & ctx, ggml_tensor * sr
const uint32_t src_ne0 = (uint32_t) src->ne[0];
const uint32_t nrows = (uint32_t) ggml_nrows(src);
const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions.wg_size);
const uint32_t npr = CEIL_DIV(src_ne0, argsort_decisions->wg_size);
const uint32_t block_size =
is_top_k ? std::min(argsort_decisions.wg_size, (uint32_t) dst->ne[0]) : argsort_decisions.wg_size;
is_top_k ? std::min(argsort_decisions->wg_size, (uint32_t) dst->ne[0]) : argsort_decisions->wg_size;
uint32_t out_ne0 = src_ne0;
if (is_top_k) {
if (npr > 1) {
const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions.wg_size;
const uint32_t last_tile = src_ne0 - (npr - 1) * argsort_decisions->wg_size;
out_ne0 = (npr - 1) * block_size + std::min(last_tile, block_size);
} else {
out_ne0 = block_size;
@ -2038,25 +2110,10 @@ static std::optional<webgpu_command> ggml_webgpu_encode_node(webgpu_context ctx,
return std::nullopt;
#endif
case GGML_OP_ADD:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->add_pipelines[node->type][inplace], inplace);
}
case GGML_OP_SUB:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->sub_pipelines[node->type][inplace], inplace);
}
case GGML_OP_MUL:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->mul_pipelines[node->type][inplace], inplace);
}
case GGML_OP_DIV:
{
int inplace = ggml_webgpu_tensor_equal(src0, node);
return ggml_webgpu_binary_op(ctx, src0, src1, node, ctx->div_pipelines[node->type][inplace], inplace);
}
return ggml_webgpu_binary_op(ctx, src0, src1, node);
case GGML_OP_RMS_NORM:
return ggml_webgpu_rms_norm(ctx, src0, node);
case GGML_OP_ROPE:
@ -2158,7 +2215,10 @@ static ggml_backend_i ggml_backend_webgpu_i = {
static void ggml_backend_webgpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
ggml_backend_webgpu_buffer_context * ctx = static_cast<ggml_backend_webgpu_buffer_context *>(buffer->context);
ctx->buffer.Destroy();
if (ctx != nullptr && ctx->buffer != nullptr) {
ctx->buffer.Destroy();
delete ctx;
}
}
// Returns the "fake" base pointer.
@ -2665,58 +2725,6 @@ static void ggml_webgpu_init_cpy_pipeline(webgpu_context & webgpu_ctx) {
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_cpy_f16_f16, "cpy_f16_f16", constants);
}
static void ggml_webgpu_init_add_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->add_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32, "add_f32", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16, "add_f16", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f32_inplace, "add_f32_inplace", constants);
webgpu_ctx->add_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_add_f16_inplace, "add_f16_inplace", constants);
}
static void ggml_webgpu_init_sub_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32, "sub_f32", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16, "sub_f16", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f32_inplace, "sub_f32_inplace", constants);
webgpu_ctx->sub_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_sub_f16_inplace, "sub_f16_inplace", constants);
}
static void ggml_webgpu_init_mul_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32, "mul_f32", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16, "mul_f16", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f32_inplace, "mul_f32_inplace", constants);
webgpu_ctx->mul_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_mul_f16_inplace, "mul_f16_inplace", constants);
}
static void ggml_webgpu_init_div_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_MAX_WG_SIZE);
webgpu_ctx->div_pipelines[GGML_TYPE_F32][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32, "div_f32", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F16][0] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16, "div_f16", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F32][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f32_inplace, "div_f32_inplace", constants);
webgpu_ctx->div_pipelines[GGML_TYPE_F16][1] =
ggml_webgpu_create_pipeline(webgpu_ctx->global_ctx->device, wgsl_div_f16_inplace, "div_f16_inplace", constants);
}
static void ggml_webgpu_init_rms_norm_pipeline(webgpu_context & webgpu_ctx) {
std::vector<wgpu::ConstantEntry> constants = ggml_webgpu_wg_size_entry(WEBGPU_ROW_SPLIT_WG_SIZE);
@ -2938,12 +2946,12 @@ static bool create_webgpu_device(ggml_backend_webgpu_reg_context * ctx) {
dev_desc.SetDeviceLostCallback(
wgpu::CallbackMode::AllowSpontaneous,
[](const wgpu::Device & device, wgpu::DeviceLostReason reason, wgpu::StringView message) {
if (reason == wgpu::DeviceLostReason::Destroyed) {
return;
}
GGML_UNUSED(device);
GGML_UNUSED(reason);
GGML_UNUSED(message);
//TODO: uncomment once proper free logic is in place
//GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
//std::string(message).c_str());
GGML_LOG_ERROR("ggml_webgpu: Device lost! Reason: %d, Message: %s\n", static_cast<int>(reason),
std::string(message).c_str());
});
dev_desc.SetUncapturedErrorCallback(
[](const wgpu::Device & device, wgpu::ErrorType reason, wgpu::StringView message) {
@ -3018,10 +3026,6 @@ static webgpu_context initialize_webgpu_context(ggml_backend_dev_t dev) {
ggml_webgpu_init_mul_mat_pipeline(webgpu_ctx);
ggml_webgpu_init_get_rows_pipeline(webgpu_ctx);
ggml_webgpu_init_cpy_pipeline(webgpu_ctx);
ggml_webgpu_init_add_pipeline(webgpu_ctx);
ggml_webgpu_init_sub_pipeline(webgpu_ctx);
ggml_webgpu_init_mul_pipeline(webgpu_ctx);
ggml_webgpu_init_div_pipeline(webgpu_ctx);
ggml_webgpu_init_rms_norm_pipeline(webgpu_ctx);
ggml_webgpu_init_rope_pipeline(webgpu_ctx);
ggml_webgpu_init_glu_pipeline(webgpu_ctx);
@ -3381,10 +3385,7 @@ static size_t ggml_backend_webgpu_reg_get_device_count(ggml_backend_reg_t reg) {
return ctx->device_count;
}
// TODO: Does this need to be thread safe? Is it only called once?
// TODO: move most logic to device_init function so backend can be freed/initialized properly
// Only one device is supported for now
static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(index == 0);
WEBGPU_LOG_DEBUG("ggml_backend_reg_get_device()");

View File

@ -1,188 +0,0 @@
#define(VARIANTS)
[
{
"SHADER_NAME": "add_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "add_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "add_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "+"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "mul_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "mul_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "*"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "sub_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "sub_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "-"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f32",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f16",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["NOT_INPLACE"]
},
{
"SHADER_NAME": "div_f32_inplace",
"REPLS": {
"TYPE" : "f32",
"OP": "/"
},
"DECLS": ["INPLACE"]
},
{
"SHADER_NAME": "div_f16_inplace",
"REPLS": {
"TYPE" : "f16",
"OP": "/"
},
"DECLS": ["INPLACE"]
}
]
#end(VARIANTS)
#define(DECLS)
#decl(NOT_INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
dst[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(2)
var<storage, read_write> dst: array<{{TYPE}}>;
@group(0) @binding(3)
var<uniform> params: Params;
#enddecl(NOT_INPLACE)
#decl(INPLACE)
fn update(dst_i: u32, src0_i: u32, src1_i: u32) {
src0[dst_i] = src0[src0_i] {{OP}} src1[src1_i];
}
@group(0) @binding(2)
var<uniform> params: Params;
#enddecl(INPLACE)
#end(DECLS)
#define(SHADER)
enable f16;
#include "binary_head.tmpl"
@group(0) @binding(0)
var<storage, read_write> src0: array<{{TYPE}}>;
@group(0) @binding(1)
var<storage, read_write> src1: array<{{TYPE}}>;
DECLS
override wg_size: u32;
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
}
}
#end(SHADER)

View File

@ -0,0 +1,107 @@
enable f16;
struct Params {
ne: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
return b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
b_i2 * params.stride_src1_2 +
b_i3 * params.stride_src1_3;
}
#ifdef TYPE_F32
#define DataType f32
#endif
#ifdef TYPE_F16
#define DataType f16
#endif
@group(0) @binding(0)
var<storage, read_write> src0: array<DataType>;
@group(0) @binding(1)
var<storage, read_write> src1 : array<DataType>;
#ifdef INPLACE
@group(0) @binding(2)
var<uniform> params: Params;
#elif defined(OVERLAP)
@group(0) @binding(2)
var<uniform> params: Params;
#else
@group(0) @binding(2)
var<storage, read_write> dst: array<DataType>;
@group(0) @binding(3)
var<uniform> params: Params;
#endif
fn op(a: DataType, b: DataType) -> DataType {
#ifdef OP_ADD
return a + b;
#elif defined(OP_SUB)
return a - b;
#elif defined(OP_MUL)
return a * b;
#elif defined(OP_DIV)
return a / b;
#endif
}
fn update(dst_i: u32, src0_i: u32, src1_i: u32){
let result = op(src0[src0_i], src1[src1_i]);
#ifdef INPLACE
src0[dst_i] = result;
#elif defined(OVERLAP)
src1[dst_i] = result;
#else
dst[dst_i] = result;
#endif
}
@compute @workgroup_size(WG_SIZE)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
if (gid.x < params.ne) {
update(params.offset_dst + gid.x, params.offset_src0 + gid.x, params.offset_src1 + src1_index(gid.x));
}
}

View File

@ -1,45 +0,0 @@
struct Params {
ne: u32,
// offsets in elements
offset_src0: u32,
offset_src1: u32,
offset_dst: u32,
stride_src1_0: u32,
stride_src1_1: u32,
stride_src1_2: u32,
stride_src1_3: u32,
a_ne0: u32,
a_ne1: u32,
a_ne2: u32,
b_ne0: u32,
b_ne1: u32,
b_ne2: u32,
b_ne3: u32,
};
fn src1_index(_i: u32) -> u32 {
var i = _i;
let a_i3 = i / (params.a_ne2 * params.a_ne1 * params.a_ne0);
i = i % (params.a_ne2 * params.a_ne1 * params.a_ne0);
let a_i2 = i / (params.a_ne1 * params.a_ne0);
i = i % (params.a_ne1 * params.a_ne0);
let a_i1 = i / params.a_ne0;
let a_i0 = i % params.a_ne0;
// handle repetition of b
// index loops back to the beginning and repeats after elements are exhausted = modulo
let b_i0 = a_i0 % params.b_ne0;
let b_i1 = a_i1 % params.b_ne1;
let b_i2 = a_i2 % params.b_ne2;
let b_i3 = a_i3 % params.b_ne3;
// compute index for position in b's flat array
return b_i0 * params.stride_src1_0 +
b_i1 * params.stride_src1_1 +
b_i2 * params.stride_src1_2 +
b_i3 * params.stride_src1_3;
}

View File

@ -5749,7 +5749,7 @@ static struct ggml_tensor * ggml_unary_impl(
struct ggml_tensor * a,
enum ggml_unary_op op,
bool inplace) {
GGML_ASSERT(ggml_is_contiguous_1(a));
GGML_ASSERT(ggml_is_contiguous_rows(a));
struct ggml_tensor * result = inplace ? ggml_view_tensor(ctx, a) : ggml_dup_tensor(ctx, a);

View File

@ -142,10 +142,13 @@ class Keys:
EMBEDDING_SCALE = "{arch}.embedding_scale"
TOKEN_SHIFT_COUNT = "{arch}.token_shift_count"
INTERLEAVE_MOE_LAYER_STEP = "{arch}.interleave_moe_layer_step"
FULL_ATTENTION_INTERVAL = "{arch}.full_attention_interval"
ACTIVATION_SPARSITY_SCALE = "{arch}.activation_sparsity_scale"
ALTUP_ACTIVE_IDX = "{arch}.altup.active_idx"
ALTUP_NUM_INPUTS = "{arch}.altup.num_inputs"
EMBD_LENGTH_PER_LAYER_INP = "{arch}.embedding_length_per_layer_input"
SWIGLU_CLAMP_EXP = "{arch}.swiglu_clamp_exp"
SWIGLU_CLAMP_SHEXP = "{arch}.swiglu_clamp_shexp"
DENSE_FEAT_IN_SIZE = "{arch}.{dense}_feat_in"
DENSE_FEAT_OUT_SIZE = "{arch}.{dense}_feat_out"
@ -179,20 +182,20 @@ class Keys:
TEMPERATURE_SCALE = "{arch}.attention.temperature_scale"
class Rope:
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
DIMENSION_COUNT = "{arch}.rope.dimension_count"
DIMENSION_SECTIONS = "{arch}.rope.dimension_sections"
FREQ_BASE = "{arch}.rope.freq_base"
FREQ_BASE_SWA = "{arch}.rope.freq_base_swa"
SCALING_TYPE = "{arch}.rope.scaling.type"
SCALING_FACTOR = "{arch}.rope.scaling.factor"
SCALING_ATTN_FACTOR = "{arch}.rope.scaling.attn_factor"
SCALING_ORIG_CTX_LEN = "{arch}.rope.scaling.original_context_length"
SCALING_FINETUNED = "{arch}.rope.scaling.finetuned"
SCALING_YARN_LOG_MUL = "{arch}.rope.scaling.yarn_log_multiplier"
SCALING_YARN_EXT_FACTOR = "{arch}.rope.scaling.yarn_ext_factor"
SCALING_YARN_ATTN_FACTOR = "{arch}.rope.scaling.yarn_attn_factor"
SCALING_YARN_BETA_FAST = "{arch}.rope.scaling.yarn_beta_fast"
SCALING_YARN_BETA_SLOW = "{arch}.rope.scaling.yarn_beta_slow"
class Split:
LLM_KV_SPLIT_NO = "split.no"
@ -382,6 +385,8 @@ class MODEL_ARCH(IntEnum):
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
QWEN35 = auto()
QWEN35MOE = auto()
PHI2 = auto()
PHI3 = auto()
PHIMOE = auto()
@ -462,6 +467,7 @@ class MODEL_ARCH(IntEnum):
PANGU_EMBED = auto()
MISTRAL3 = auto()
MIMO2 = auto()
STEP35 = auto()
LLAMA_EMBED = auto()
MAINCODER = auto()
KIMI_LINEAR = auto()
@ -554,13 +560,14 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_ALPHA = auto() # qwen3.5
SSM_BETA_ALPHA = auto() # qwen3next
SSM_CONV1D_Q = auto() # Kimi Linear
SSM_CONV1D_K = auto() # Kimi Linear
SSM_CONV1D_V = auto() # Kimi Linear
SSM_F_A = auto() # Kimi Linear
SSM_F_B = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear
SSM_BETA = auto() # Kimi Linear qwen3.5
SSM_G_A = auto() # Kimi Linear
SSM_G_B = auto() # Kimi Linear
TIME_MIX_W0 = auto()
@ -811,6 +818,8 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.QWEN35: "qwen35",
MODEL_ARCH.QWEN35MOE: "qwen35moe",
MODEL_ARCH.PHI2: "phi2",
MODEL_ARCH.PHI3: "phi3",
MODEL_ARCH.PHIMOE: "phimoe",
@ -892,6 +901,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.PANGU_EMBED: "pangu-embedded",
MODEL_ARCH.MISTRAL3: "mistral3",
MODEL_ARCH.MIMO2: "mimo2",
MODEL_ARCH.STEP35: "step35",
MODEL_ARCH.LLAMA_EMBED: "llama-embed",
MODEL_ARCH.MAINCODER: "maincoder",
MODEL_ARCH.KIMI_LINEAR: "kimi-linear",
@ -981,13 +991,14 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_ALPHA: "blk.{bid}.ssm_alpha", # qwen3.5
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.SSM_CONV1D_Q: "blk.{bid}.ssm_conv1d_q", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_K: "blk.{bid}.ssm_conv1d_k", # Kimi Linear
MODEL_TENSOR.SSM_CONV1D_V: "blk.{bid}.ssm_conv1d_v", # Kimi Linear
MODEL_TENSOR.SSM_F_A: "blk.{bid}.ssm_f_a", # Kimi Linear
MODEL_TENSOR.SSM_F_B: "blk.{bid}.ssm_f_b", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear
MODEL_TENSOR.SSM_BETA: "blk.{bid}.ssm_beta", # Kimi Linear qwen3.5
MODEL_TENSOR.SSM_G_A: "blk.{bid}.ssm_g_a", # Kimi Linear
MODEL_TENSOR.SSM_G_B: "blk.{bid}.ssm_g_b", # Kimi Linear
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
@ -1814,6 +1825,61 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN35: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN35MOE: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_QKV,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_BETA,
MODEL_TENSOR.SSM_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.PLAMO: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -3364,6 +3430,32 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.STEP35: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ROPE_FREQS,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.FFN_NORM,
MODEL_TENSOR.FFN_GATE,
MODEL_TENSOR.FFN_DOWN,
MODEL_TENSOR.FFN_UP,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_EXP_PROBS_B,
],
MODEL_ARCH.LLAMA_EMBED: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
@ -3674,6 +3766,7 @@ class VisionProjectorType:
VOXTRAL = "voxtral"
LFM2 = "lfm2"
KIMIVL = "kimivl"
KIMIK25 = "kimik25"
LIGHTONOCR = "lightonocr"
COGVLM = "cogvlm"
JANUS_PRO = "janus_pro"
@ -3753,12 +3846,12 @@ KEY_ATTENTION_LAYERNORM_EPS = Keys.Attention.LAYERNORM_EPS
KEY_ATTENTION_LAYERNORM_RMS_EPS = Keys.Attention.LAYERNORM_RMS_EPS
# RoPE
KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
KEY_ROPE_DIMENSION_COUNT = Keys.Rope.DIMENSION_COUNT
KEY_ROPE_FREQ_BASE = Keys.Rope.FREQ_BASE
KEY_ROPE_SCALING_TYPE = Keys.Rope.SCALING_TYPE
KEY_ROPE_SCALING_FACTOR = Keys.Rope.SCALING_FACTOR
KEY_ROPE_SCALING_ORIG_CTX_LEN = Keys.Rope.SCALING_ORIG_CTX_LEN
KEY_ROPE_SCALING_FINETUNED = Keys.Rope.SCALING_FINETUNED
# SSM
KEY_SSM_CONV_KERNEL = Keys.SSM.CONV_KERNEL

View File

@ -708,6 +708,9 @@ class GGUFWriter:
def add_leading_dense_block_count(self, length: int) -> None:
self.add_uint32(Keys.LLM.LEADING_DENSE_BLOCK_COUNT.format(arch=self.arch), length)
def add_full_attention_interval(self, interval: int) -> None:
self.add_uint32(Keys.LLM.FULL_ATTENTION_INTERVAL.format(arch=self.arch), interval)
def add_feed_forward_length(self, length: int | Sequence[int]) -> None:
if isinstance(length, int):
self.add_uint32(Keys.LLM.FEED_FORWARD_LENGTH.format(arch=self.arch), length)
@ -824,6 +827,12 @@ class GGUFWriter:
def add_expert_gating_func(self, value: ExpertGatingFuncType) -> None:
self.add_uint32(Keys.LLM.EXPERT_GATING_FUNC.format(arch=self.arch), value.value)
def add_swiglu_clamp_exp(self, values: Sequence[float]) -> None:
self.add_array(Keys.LLM.SWIGLU_CLAMP_EXP.format(arch=self.arch), values)
def add_swiglu_clamp_shexp(self, values: Sequence[float]) -> None:
self.add_array(Keys.LLM.SWIGLU_CLAMP_SHEXP.format(arch=self.arch), values)
def add_expert_group_scale(self, value: float) -> None:
self.add_float32(Keys.LLM.EXPERT_GROUP_SCALE.format(arch=self.arch), value)

View File

@ -228,6 +228,7 @@ class TensorNameMap:
"transformer_encoder.{bid}.qkv", # neobert
"layers.{bid}.attn.Wqkv", # modern-bert
"model.layers.{bid}.self_attn.language_expert_query_key_value", # cogvlm
"model.layers.{bid}.linear_attn.in_proj_qkv", # qwen3.5
),
# Attention query
@ -359,6 +360,8 @@ class TensorNameMap:
MODEL_TENSOR.ATTN_GATE: (
"model.layers.{bid}.self_attn.gate_proj", # afmoe
"model.layers.{bid}.linear_attn.in_proj_z", # qwen3.5
"model.layers.{bid}.self_attn.g_proj", # step3.5 head-wise attention gate
),
# Feed-forward norm
@ -423,6 +426,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.router.gate", # afmoe
"layers.{bid}.gate", # mistral-large
"backbone.layers.{bid}.mixer.gate", # nemotron-h-moe
"model.layers.{bid}.moe.gate", # step3.5
),
MODEL_TENSOR.FFN_GATE_INP_SHEXP: (
@ -439,6 +443,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.gate.e_score_correction", # nemotron-h-moe
"model.layers.{bid}.mlp.e_score_correction", # exaone-moe
"model.layers.{bid}.block_sparse_moe.gate.e_score_correction", # kimi
"model.layers.{bid}.moe.router_bias", # step3.5 expert selection bias
),
# Feed-forward up
@ -493,6 +498,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.experts.up_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w1", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.up", # smallthinker
"model.layers.{bid}.moe.up_proj", # step3.5
),
MODEL_TENSOR.FFN_UP_SHEXP: (
@ -504,6 +510,7 @@ class TensorNameMap:
"layers.{bid}.shared_experts.w3", # mistral-large
"backbone.layers.{bid}.mixer.shared_experts.up_proj", # nemotron-h-moe
"model.layers.{bid}.block_sparse_moe.shared_experts.up_proj", # kimi
"model.layers.{bid}.share_expert.up_proj", # step3.5
),
MODEL_TENSOR.FFN_UP_CHEXP: (
@ -543,6 +550,7 @@ class TensorNameMap:
"model.layers.{bid}.block_sparse_moe.experts.w1", # phimoe (merged)
"model.layers.{bid}.feed_forward.experts.gate_proj", # llama4
"model.layers.{bid}.block_sparse_moe.experts.gate", # smallthinker
"model.layers.{bid}.moe.gate_proj", # step3.5
),
MODEL_TENSOR.FFN_GATE_SHEXP: (
@ -552,6 +560,7 @@ class TensorNameMap:
"model.layers.{bid}.mlp.shared_mlp.gate_proj", # hunyuan
"layers.{bid}.shared_experts.w1", # mistral-large
"model.layers.{bid}.block_sparse_moe.shared_experts.gate_proj", # kimi
"model.layers.{bid}.share_expert.gate_proj", # step3.5
),
MODEL_TENSOR.FFN_GATE_CHEXP: (
@ -606,6 +615,7 @@ class TensorNameMap:
"model.layers.{bid}.feed_forward.experts.down_proj", # llama4
"encoder.layers.{bid}.mlp.experts.mlp.w2", # nomic-bert-moe
"model.layers.{bid}.block_sparse_moe.experts.down", # smallthinker
"model.layers.{bid}.moe.down_proj", # step3.5
),
MODEL_TENSOR.FFN_DOWN_SHEXP: (
@ -617,6 +627,7 @@ class TensorNameMap:
"layers.{bid}.shared_experts.w2", # mistral-large
"backbone.layers.{bid}.mixer.shared_experts.down_proj", # nemotron-h-moe
"model.layers.{bid}.block_sparse_moe.shared_experts.down_proj", # kimi
"model.layers.{bid}.share_expert.down_proj", # step3.5
),
MODEL_TENSOR.FFN_DOWN_CHEXP: (
@ -814,6 +825,10 @@ class TensorNameMap:
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),
MODEL_TENSOR.SSM_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_a", # qwen3.5
),
MODEL_TENSOR.SSM_BETA_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
),
@ -835,7 +850,8 @@ class TensorNameMap:
"model.layers.{bid}.self_attn.f_b_proj",
),
MODEL_TENSOR.SSM_BETA: (
"model.layers.{bid}.self_attn.b_proj",
"model.layers.{bid}.linear_attn.in_proj_b", # qwen3.5
"model.layers.{bid}.self_attn.b_proj", # Kimi Linear
),
MODEL_TENSOR.SSM_G_A: (
"model.layers.{bid}.self_attn.g_a_proj",
@ -1287,6 +1303,7 @@ class TensorNameMap:
MODEL_TENSOR.V_MMPROJ: (
"multi_modal_projector.linear_{bid}",
"mm_projector.proj.linear_{bid}", # Kimi-K2.5
"visual.merger.mlp.{bid}", # qwen2vl
"merger.mlp.{bid}",
),
@ -1348,6 +1365,7 @@ class TensorNameMap:
MODEL_TENSOR.V_ENC_ATTN_QKV: (
"visual.blocks.{bid}.attn.qkv", # qwen3vl
"model.vision.transformer.layers.{bid}.attention.query_key_value", # cogvlm
"vision_tower.encoder.blocks.{bid}.wqkv" # Kimi-K2.5
),
MODEL_TENSOR.V_ENC_ATTN_Q: (
@ -1522,6 +1540,7 @@ class TensorNameMap:
"multi_modal_projector.norm",
"multi_modal_projector.layer_norm",
"multi_modal_projector.pre_norm",
"mm_projector.pre_norm", # Kimi-K2.5
"pre_mm_projector_norm",
"model.vision.linear_proj.norm1", # cogvlm
"merger.ln_q",

View File

@ -23,7 +23,7 @@ numpy = ">=1.17"
tqdm = ">=4.27"
pyyaml = ">=5.1"
requests = ">=2.25"
sentencepiece = { version = ">=0.1.98,<=0.2.0", optional = true }
sentencepiece = { version = ">=0.1.98,<0.3.0", optional = true }
PySide6 = { version = "^6.9", python = ">=3.9,<3.14", optional = true }
[tool.poetry.dev-dependencies]

View File

@ -482,7 +482,7 @@ extern "C" {
enum llama_params_fit_status {
LLAMA_PARAMS_FIT_STATUS_SUCCESS = 0, // found allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_FAILURE = 1, // could not find allocations that are projected to fit
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occured, e.g. because no model could be found at the specified path
LLAMA_PARAMS_FIT_STATUS_ERROR = 2, // a hard error occurred, e.g. because no model could be found at the specified path
};
// fits mparams and cparams to free device memory (assumes system memory is unlimited)
@ -1150,9 +1150,9 @@ extern "C" {
//
/// Apply chat template. Inspired by hf apply_chat_template() on python.
/// Both "model" and "custom_template" are optional, but at least one is required. "custom_template" has higher precedence than "model"
///
/// NOTE: This function does not use a jinja parser. It only support a pre-defined list of template. See more: https://github.com/ggml-org/llama.cpp/wiki/Templates-supported-by-llama_chat_apply_template
/// @param tmpl A Jinja template to use for this chat. If this is nullptr, the models default chat template will be used instead.
/// @param tmpl A Jinja template to use for this chat.
/// @param chat Pointer to a list of multiple llama_chat_message
/// @param n_msg Number of llama_chat_message in this chat
/// @param add_ass Whether to end the prompt with the token(s) that indicate the start of an assistant message.

View File

@ -17,7 +17,7 @@ classifiers = [
[tool.poetry.dependencies]
python = ">=3.9"
numpy = "^1.25.0"
sentencepiece = ">=0.1.98,<=0.2.0"
sentencepiece = ">=0.1.98,<0.3.0"
transformers = ">=4.35.2,<5.0.0"
protobuf = ">=4.21.0,<5.0.0"
gguf = { path = "./gguf-py" }

View File

@ -1,5 +1,5 @@
numpy~=1.26.4
sentencepiece~=0.2.0
sentencepiece>=0.1.98,<0.3.0
transformers>=4.57.1,<5.0.0

Some files were not shown because too many files have changed in this diff Show More