Compare commits


No commits in common. "master" and "b7588" have entirely different histories.

183 changed files with 2993 additions and 12069 deletions


@@ -1098,7 +1098,6 @@ jobs:
           save: ${{ github.event_name == 'push' && github.ref == 'refs/heads/master' }}
 
       - name: Build with CMake
-        # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
         run: |
           cmake -S . -B build -G Ninja \
             -DLLAMA_CURL=OFF \
@@ -1108,8 +1107,7 @@ jobs:
             -DCMAKE_CUDA_ARCHITECTURES=89-real \
             -DCMAKE_EXE_LINKER_FLAGS=-Wl,--allow-shlib-undefined \
             -DGGML_NATIVE=OFF \
-            -DGGML_CUDA=ON \
-            -DGGML_CUDA_CUB_3DOT2=ON
+            -DGGML_CUDA=ON
           cmake --build build
 
   windows-2022-cmake-cuda:
@@ -1145,7 +1143,6 @@ jobs:
       - name: Build
         id: cmake_build
         shell: cmd
-        # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
         run: |
           call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
           cmake -S . -B build -G "Ninja Multi-Config" ^
@@ -1156,8 +1153,7 @@ jobs:
             -DGGML_BACKEND_DL=ON ^
             -DGGML_CPU_ALL_VARIANTS=ON ^
             -DGGML_CUDA=ON ^
-            -DGGML_RPC=ON ^
-            -DGGML_CUDA_CUB_3DOT2=ON
+            -DGGML_RPC=ON
           set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
           cmake --build build --config Release -j %NINJA_JOBS% -t ggml
           cmake --build build --config Release
@@ -1418,6 +1414,7 @@ jobs:
           echo "FIXME: test on devices"
 
   openEuler-latest-cmake-cann:
+    if: ${{ github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'Ascend NPU') }}
     defaults:
       run:
         shell: bash -el {0}
@@ -1753,7 +1750,7 @@ jobs:
           sudo apt-get update
           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache
           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1765,8 +1762,6 @@
           rustup install stable
           rustup default stable
-
-          git lfs install
 
       - name: Clone
         id: checkout
         uses: actions/checkout@v4
@@ -1852,7 +1847,7 @@ jobs:
           sudo apt-get update
           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache
           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1864,8 +1859,6 @@
           rustup install stable
           rustup default stable
-
-          git lfs install
 
       - name: GCC version check
         run: |
           gcc --version
@@ -1946,7 +1939,7 @@ jobs:
           sudo apt-get update
           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache git-lfs
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential wget ccache
           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -1958,8 +1951,6 @@
           rustup install stable
           rustup default stable
-
-          git lfs install
 
       - name: GCC version check
         run: |
           gcc --version
@@ -2020,7 +2011,7 @@ jobs:
           sudo apt-get update
           # Install necessary packages
-          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache git-lfs
+          sudo apt-get install -y libatomic1 libtsan2 gcc-14 g++-14 rustup cmake build-essential libssl-dev wget ccache
           # Set gcc-14 and g++-14 as the default compilers
           sudo update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-14 100
@@ -2032,8 +2023,6 @@
           rustup install stable
           rustup default stable
-
-          git lfs install
 
       - name: GCC version check
         run: |
           gcc --version


@@ -420,7 +420,6 @@ jobs:
       - name: Build
         id: cmake_build
         shell: cmd
-        # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
         run: |
           call "C:\Program Files\Microsoft Visual Studio\2022\Enterprise\VC\Auxiliary\Build\vcvarsall.bat" x64
           cmake -S . -B build -G "Ninja Multi-Config" ^
@@ -428,8 +427,7 @@ jobs:
             -DGGML_NATIVE=OFF ^
             -DGGML_CPU=OFF ^
             -DGGML_CUDA=ON ^
-            -DLLAMA_CURL=OFF ^
-            -DGGML_CUDA_CUB_3DOT2=ON
+            -DLLAMA_CURL=OFF
           set /A NINJA_JOBS=%NUMBER_OF_PROCESSORS%-1
           cmake --build build --config Release -j %NINJA_JOBS% --target ggml-cuda


@@ -41,10 +41,6 @@ jobs:
         include:
           - build_type: Release
             sanitizer: ""
-            extra_args: ""
-          - build_type: Release
-            sanitizer: ""
-            extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
       fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
 
     steps:
@@ -69,12 +65,6 @@
           fetch-depth: 0
           ref: ${{ github.event.inputs.sha || github.event.pull_request.head.sha || github.sha || github.head_ref || github.ref_name }}
 
-      - name: Build
-        id: cmake_build
-        run: |
-          cmake -B build -DLLAMA_CURL=OFF -DLLAMA_BUILD_BORINGSSL=ON
-          cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
-
       - name: Python setup
         id: setup_python
         uses: actions/setup-python@v5
@@ -86,14 +76,6 @@
         run: |
           pip install -r tools/server/tests/requirements.txt
 
-      - name: Tests
-        id: server_integration_tests
-        if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }}
-        run: |
-          cd tools/server/tests
-          export ${{ matrix.extra_args }}
-          pytest -v -x -m "not slow"
-
   server-windows:
     runs-on: windows-2022

.gitignore

@@ -130,7 +130,6 @@ poetry.toml
 # Local scripts
 /run-vim.sh
 /run-chat.sh
-/run-spec.sh
 /.ccache/
 
 # IDE


@@ -52,8 +52,7 @@ if [ ! -z ${GG_BUILD_METAL} ]; then
 fi
 
 if [ ! -z ${GG_BUILD_CUDA} ]; then
-    # TODO: Remove GGML_CUDA_CUB_3DOT2 flag once CCCL 3.2 is bundled within CTK and that CTK version is used in this project
-    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON -DGGML_CUDA_CUB_3DOT2=ON"
+    CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_CUDA=ON"
 
     if command -v nvidia-smi >/dev/null 2>&1; then
         CUDA_ARCH=$(nvidia-smi --query-gpu=compute_cap --format=csv,noheader,nounits 2>/dev/null | head -1 | tr -d '.')


@@ -854,54 +854,6 @@ bool common_arg_utils::is_autoy(const std::string & value) {
     return value == "auto" || value == "-1";
 }
 
-
-// Simple CSV parser that handles quoted fields and escaped quotes
-// example:
-// input: value1,"value, with, commas","value with ""escaped"" quotes",value4
-// output: [value1] [value, with, commas] [value with "escaped" quotes] [value4]
-static std::vector<std::string> parse_csv_row(const std::string& input) {
-    std::vector<std::string> fields;
-    std::string field;
-    bool in_quotes = false;
-
-    for (size_t i = 0; i < input.length(); ++i) {
-        char ch = input[i];
-        if (ch == '"') {
-            if (!in_quotes) {
-                // start of quoted field (only valid if at beginning of field)
-                if (!field.empty()) {
-                    // quote appeared in middle of unquoted field, treat as literal
-                    field += '"';
-                } else {
-                    in_quotes = true; // start
-                }
-            } else {
-                if (i + 1 < input.length() && input[i + 1] == '"') {
-                    // escaped quote: ""
-                    field += '"';
-                    ++i; // skip the next quote
-                } else {
-                    in_quotes = false; // end
-                }
-            }
-        } else if (ch == ',') {
-            if (in_quotes) {
-                field += ',';
-            } else {
-                fields.push_back(std::move(field));
-                field.clear();
-            }
-        } else {
-            field += ch;
-        }
-    }
-
-    // Add the last field
-    fields.push_back(std::move(field));
-
-    return fields;
-}
-
 common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
     // per-example default params
     // we define here to make sure it's included in llama-gen-docs
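
Note: the removed parse_csv_row honored CSV quoting, while the string_split<std::string>(value, ',') calls that replace it (see the hunks below) split on every comma. A minimal standalone sketch of the difference; naive_split is a hypothetical stand-in for string_split:

    #include <iostream>
    #include <sstream>
    #include <string>
    #include <vector>

    // naive comma split, analogous to string_split<std::string>(value, ',')
    static std::vector<std::string> naive_split(const std::string & input, char sep) {
        std::vector<std::string> parts;
        std::string part;
        std::stringstream ss(input);
        while (std::getline(ss, part, sep)) {
            parts.push_back(part);
        }
        return parts;
    }

    int main() {
        const std::string row = "value1,\"value, with, commas\",value3";
        // quoted commas are no longer protected once parse_csv_row is gone:
        for (const auto & f : naive_split(row, ',')) {
            std::cout << "[" << f << "] ";
        }
        std::cout << "\n";
        // prints: [value1] ["value] [ with] [ commas"] [value3]
    }

With the removed parser, the same row would have yielded [value1] [value, with, commas] [value3].
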
@@ -1298,7 +1250,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--in-file"}, "FNAME",
         "an input file (use comma-separated values to specify multiple files)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 std::ifstream file(item);
                 if (!file) {
                     throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -1445,7 +1397,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, bool value) {
             params.warmup = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_DEBUG}));
+    ).set_examples({LLAMA_EXAMPLE_COMPLETION, LLAMA_EXAMPLE_CLI, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MTMD, LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_PERPLEXITY}));
     add_opt(common_arg(
         {"--spm-infill"},
         string_format(
@@ -1743,13 +1695,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             params.sampling.grammar = json_schema_to_grammar(json::parse(schema));
         }
     ).set_sparam());
-    add_opt(common_arg(
-        {"-bs", "--backend-sampling"},
-        "enable backend sampling (experimental) (default: disabled)",
-        [](common_params & params) {
-            params.sampling.backend_sampling = true;
-        }
-    ).set_sparam().set_env("LLAMA_ARG_BACKEND_SAMPLING"));
     add_opt(common_arg(
         {"--pooling"}, "{none,mean,cls,last,rank}",
         "pooling type for embeddings, use model default if unspecified",
@@ -1761,7 +1706,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             else if (value == "rank") { params.pooling_type = LLAMA_POOLING_TYPE_RANK; }
             else { throw std::invalid_argument("invalid value"); }
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_POOLING"));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_RETRIEVAL, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_POOLING"));
     add_opt(common_arg(
         {"--attention"}, "{causal,non-causal}",
         "attention type for embeddings, use model default if unspecified",
@@ -2050,7 +1995,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--image", "--audio"}, "FILE",
         "path to an image or audio file. use with multimodal models, use comma-separated values for multiple files\n",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 params.image.emplace_back(item);
             }
         }
@@ -2307,12 +2252,37 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ));
     add_opt(common_arg(
         {"--override-kv"}, "KEY=TYPE:VALUE,...",
-        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated values.\n"
+        "advanced option to override model metadata by key. to specify multiple overrides, either use comma-separated or repeat this argument.\n"
         "types: int, float, bool, str. example: --override-kv tokenizer.ggml.add_bos_token=bool:false,tokenizer.ggml.add_eos_token=bool:false",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
-                if (!string_parse_kv_override(item.c_str(), params.kv_overrides)) {
-                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", item.c_str()));
+            std::vector<std::string> kv_overrides;
+
+            std::string current;
+            bool escaping = false;
+
+            for (const char c : value) {
+                if (escaping) {
+                    current.push_back(c);
+                    escaping = false;
+                } else if (c == '\\') {
+                    escaping = true;
+                } else if (c == ',') {
+                    kv_overrides.push_back(current);
+                    current.clear();
+                } else {
+                    current.push_back(c);
+                }
+            }
+
+            if (escaping) {
+                current.push_back('\\');
+            }
+
+            kv_overrides.push_back(current);
+
+            for (const auto & kv_override : kv_overrides) {
+                if (!string_parse_kv_override(kv_override.c_str(), params.kv_overrides)) {
+                    throw std::runtime_error(string_format("error: Invalid type for KV override: %s\n", kv_override.c_str()));
                 }
             }
         }
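
Note: the escape-aware loop on the + side treats a backslash-escaped comma as part of the current element rather than as a separator (the old CSV scheme used double quotes for this instead). A self-contained sketch of that exact logic, pulled out into a hypothetical split_escaped helper:

    #include <iostream>
    #include <string>
    #include <vector>

    // same escape handling as the --override-kv lambda above
    static std::vector<std::string> split_escaped(const std::string & value) {
        std::vector<std::string> out;
        std::string current;
        bool escaping = false;
        for (const char c : value) {
            if (escaping) {
                current.push_back(c);
                escaping = false;
            } else if (c == '\\') {
                escaping = true;
            } else if (c == ',') {
                out.push_back(current);
                current.clear();
            } else {
                current.push_back(c);
            }
        }
        if (escaping) {
            current.push_back('\\'); // a dangling backslash is kept literally
        }
        out.push_back(current);
        return out;
    }

    int main() {
        // the escaped comma stays inside the first override
        for (const auto & kv : split_escaped("tokenizer.x=str:a\\,b,foo.bar=int:2")) {
            std::cout << "[" << kv << "]\n";
        }
        // prints:
        // [tokenizer.x=str:a,b]
        // [foo.bar=int:2]
    }
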
@@ -2329,7 +2299,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--lora"}, "FNAME",
         "path to LoRA adapter (use comma-separated values to load multiple adapters)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 params.lora_adapters.push_back({ item, 1.0, "", "", nullptr });
             }
         }
@@ -2340,7 +2310,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "path to LoRA adapter with user defined scaling (format: FNAME:SCALE,...)\n"
         "note: use comma-separated values",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 auto parts = string_split<std::string>(item, ':');
                 if (parts.size() != 2) {
                     throw std::invalid_argument("lora-scaled format: FNAME:SCALE");
@@ -2354,7 +2324,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--control-vector"}, "FNAME",
         "add a control vector\nnote: use comma-separated values to add multiple control vectors",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 params.control_vectors.push_back({ 1.0f, item, });
             }
         }
@@ -2364,7 +2334,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         "add a control vector with user defined scaling SCALE\n"
         "note: use comma-separated values (format: FNAME:SCALE,...)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 auto parts = string_split<std::string>(item, ':');
                 if (parts.size() != 2) {
                     throw std::invalid_argument("control-vector-scaled format: FNAME:SCALE");
@@ -2462,7 +2432,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         {"--context-file"}, "FNAME",
         "file to load context from (use comma-separated values to specify multiple files)",
         [](common_params & params, const std::string & value) {
-            for (const auto & item : parse_csv_row(value)) {
+            for (const auto & item : string_split<std::string>(value, ',')) {
                 std::ifstream file(item, std::ios::binary);
                 if (!file) {
                     throw std::runtime_error(string_format("error: failed to open file '%s'\n", item.c_str()));
@@ -2609,7 +2579,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, int value) {
             params.embd_normalize = value;
         }
-    ).set_examples({LLAMA_EXAMPLE_EMBEDDING, LLAMA_EXAMPLE_DEBUG}));
+    ).set_examples({LLAMA_EXAMPLE_EMBEDDING}));
     add_opt(common_arg(
         {"--embd-output-format"}, "FORMAT",
         "empty = default, \"array\" = [[],[]...], \"json\" = openai style, \"json+\" = same \"json\" + cosine similarity matrix, \"raw\" = plain whitespace-delimited output (one embedding per line)",
@@ -2687,7 +2657,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params) {
             params.embedding = true;
         }
-    ).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_DEBUG}).set_env("LLAMA_ARG_EMBEDDINGS"));
+    ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_EMBEDDINGS"));
     add_opt(common_arg(
         {"--rerank", "--reranking"},
         string_format("enable reranking endpoint on server (default: %s)", "disabled"),
@@ -2698,13 +2668,9 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_RERANKING"));
     add_opt(common_arg(
         {"--api-key"}, "KEY",
-        "API key to use for authentication, multiple keys can be provided as a comma-separated list (default: none)",
+        "API key to use for authentication (default: none)",
         [](common_params & params, const std::string & value) {
-            for (const auto & key : parse_csv_row(value)) {
-                if (!key.empty()) {
-                    params.api_keys.push_back(key);
-                }
-            }
+            params.api_keys.push_back(value);
         }
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_API_KEY"));
     add_opt(common_arg(
@@ -2718,7 +2684,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             std::string key;
             while (std::getline(key_file, key)) {
                 if (!key.empty()) {
                     params.api_keys.push_back(key);
                 }
             }
             key_file.close();
@@ -2740,7 +2706,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
     ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_SSL_CERT_FILE"));
     add_opt(common_arg(
         {"--chat-template-kwargs"}, "STRING",
-        "sets additional params for the json template parser, must be a valid json object string, e.g. '{\"key1\":\"value1\",\"key2\":\"value2\"}'",
+        string_format("sets additional params for the json template parser"),
         [](common_params & params, const std::string & value) {
             auto parsed = json::parse(value);
             for (const auto & item : parsed.items()) {
@@ -3378,27 +3344,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
             }
         }
     ).set_examples({ LLAMA_EXAMPLE_FINETUNE }));
-    add_opt(common_arg(
-        {"--save-logits"},
-        string_format("save final logits to files for verification (default: %s)", params.save_logits ? "true" : "false"),
-        [](common_params & params) {
-            params.save_logits = true;
-        }
-    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
-    add_opt(common_arg(
-        {"--logits-output-dir"}, "PATH",
-        string_format("directory for saving logits output files (default: %s)", params.logits_output_dir.c_str()),
-        [](common_params & params, const std::string & value) {
-            params.logits_output_dir = value;
-        }
-    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
-    add_opt(common_arg(
-        {"--tensor-filter"}, "REGEX",
-        "filter tensor names for debug output (regex pattern, can be specified multiple times)",
-        [](common_params & params, const std::string & value) {
-            params.tensor_filter.push_back(value);
-        }
-    ).set_examples({LLAMA_EXAMPLE_DEBUG}));
 
     // presets
     add_opt(common_arg(


@@ -1395,14 +1395,6 @@ static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
     builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
 }
 
-static void common_chat_parse_solar_open(common_chat_msg_parser & builder) {
-    builder.try_parse_reasoning("<|think|>", "<|end|><|begin|>assistant<|content|>");
-
-    // TODO: Tool calling
-    builder.add_content(builder.consume_rest());
-}
-
 static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
     builder.try_parse_reasoning("<think>", "</think>");
     builder.add_content(builder.consume_rest());
@@ -1487,9 +1479,6 @@
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
             common_chat_parse_xiaomi_mimo(builder);
             break;
-        case COMMON_CHAT_FORMAT_SOLAR_OPEN:
-            common_chat_parse_solar_open(builder);
-            break;
         default:
             throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
     }


@@ -380,8 +380,8 @@ std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const json & too
             const auto & function = tool.at("function");
             result.push_back({
                 /* .name = */ function.at("name"),
-                /* .description = */ function.value("description", ""),
-                /* .parameters = */ function.value("parameters", json::object()).dump(),
+                /* .description = */ function.at("description"),
+                /* .parameters = */ function.at("parameters").dump(),
             });
         }
     }
@@ -669,7 +669,6 @@ const char * common_chat_format_name(common_chat_format format) {
         case COMMON_CHAT_FORMAT_QWEN3_CODER_XML: return "Qwen3 Coder";
         case COMMON_CHAT_FORMAT_APRIEL_1_5: return "Apriel 1.5";
         case COMMON_CHAT_FORMAT_XIAOMI_MIMO: return "Xiaomi MiMo";
-        case COMMON_CHAT_FORMAT_SOLAR_OPEN: return "Solar Open";
         case COMMON_CHAT_FORMAT_PEG_SIMPLE: return "peg-simple";
         case COMMON_CHAT_FORMAT_PEG_NATIVE: return "peg-native";
         case COMMON_CHAT_FORMAT_PEG_CONSTRUCTED: return "peg-constructed";
@@ -2065,7 +2064,7 @@ static common_chat_params common_chat_params_init_gpt_oss(const common_chat_temp
     // Trigger on tool calls that appear in the commentary channel
     data.grammar_triggers.push_back({
         COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
-        "<\\|channel\\|>(?:commentary|analysis) to"
+        "<\\|channel\\|>(commentary|analysis) to"
     });
 
     // Trigger tool calls that appear in the role section, either at the
@@ -2398,17 +2397,17 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
         (inputs.parallel_tool_calls ? "(" + tool_call + ")+" : tool_call));
 
     // Trigger on some common known "good bad" outputs (only from the start and with a json that's about a specific argument name to avoid false positives)
     data.grammar_triggers.push_back({
-        COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN,
+        COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
         // If thinking_forced_open, then we capture the </think> tag in the grammar,
         // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
-        std::string(data.thinking_forced_open ? "(</think>\\s*)" : "") + (
+        std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
             "\\s*("
                 "(?:<tool_call>"
                 "|<function"
                 "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
                     "\\s*\\{\\s*\"name\"\\s*:\\s*\"(?:" + string_join(escaped_names, "|") + ")\""
                 ")"
-            ")"
+            ")[\\s\\S]*"
         ),
     });
     data.preserved_tokens = {
@@ -2518,27 +2517,6 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
     return data;
 }
 
-static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
-    common_chat_params data;
-
-    // TODO: Reasoning effort
-    json additional_context = {};
-
-    data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
-    data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
-
-    data.preserved_tokens = {
-        "<|think|>",
-        "<|content|>",
-        "<|begin|>",
-        "<|end|>",
-    };
-
-    // TODO: Tool calling
-    return data;
-}
-
 static common_chat_params common_chat_params_init_without_tools(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
     data.prompt = apply(tmpl, inputs);
@@ -2802,13 +2780,6 @@ static common_chat_params common_chat_templates_apply_jinja(
         return common_chat_params_init_magistral(tmpl, params);
     }
 
-    // Solar Open
-    if (src.find("<|tool_response:begin|>") != std::string::npos &&
-        src.find("<|tool_response:name|>") != std::string::npos &&
-        src.find("<|tool_response:result|>") != std::string::npos) {
-        return common_chat_params_init_solar_open(tmpl, params);
-    }
-
     // Plain handler (no tools)
     if (params.tools.is_null() || inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_NONE) {
         return common_chat_params_init_without_tools(tmpl, params);


@@ -124,7 +124,6 @@ enum common_chat_format {
     COMMON_CHAT_FORMAT_QWEN3_CODER_XML,
     COMMON_CHAT_FORMAT_APRIEL_1_5,
     COMMON_CHAT_FORMAT_XIAOMI_MIMO,
-    COMMON_CHAT_FORMAT_SOLAR_OPEN,
 
     // These are intended to be parsed by the PEG parser
     COMMON_CHAT_FORMAT_PEG_SIMPLE,


@@ -1086,7 +1086,6 @@ struct common_init_result::impl {
     std::vector<llama_adapter_lora_ptr> lora;
 
     std::vector<common_sampler_ptr> samplers;
-    std::vector<llama_sampler_seq_config> samplers_seq_config;
 };
 
 common_init_result::common_init_result(common_params & params) :
@@ -1163,19 +1162,10 @@
     //    params.sampling.dry_penalty_last_n = llama_n_ctx(lctx);
     //}
 
-    // init the backend samplers as part of the context creation
     pimpl->samplers.resize(cparams.n_seq_max);
-    pimpl->samplers_seq_config.resize(cparams.n_seq_max);
 
     for (int i = 0; i < (int) cparams.n_seq_max; ++i) {
         pimpl->samplers[i].reset(common_sampler_init(model, params.sampling));
-        pimpl->samplers_seq_config[i] = { i, common_sampler_get(pimpl->samplers[i].get()) };
-    }
-
-    // TODO: temporarily gated behind a flag
-    if (params.sampling.backend_sampling) {
-        cparams.samplers = pimpl->samplers_seq_config.data();
-        cparams.n_samplers = pimpl->samplers_seq_config.size();
     }
 
     llama_context * lctx = llama_init_from_model(model, cparams);
@@ -1199,12 +1189,6 @@ common_sampler * common_init_result::sampler(llama_seq_id seq_id) {
     return pimpl->samplers[seq_id].get();
 }
 
-void common_init_result::reset_samplers() {
-    for (int i = 0; i < (int) pimpl->samplers.size(); ++i) {
-        llama_sampler_reset(common_sampler_get(pimpl->samplers[i].get()));
-    }
-}
-
 std::vector<llama_adapter_lora_ptr> & common_init_result::lora() {
     return pimpl->lora;
 }
@@ -1320,9 +1304,6 @@ common_init_result_ptr common_init_from_params(common_params & params) {
         llama_synchronize(lctx);
         llama_perf_context_reset(lctx);
         llama_set_warmup(lctx, false);
-
-        // reset samplers to reset RNG state after warmup to the seeded state
-        res->reset_samplers();
     }
 
     return res;


@@ -80,7 +80,6 @@ int32_t cpu_get_num_math();
 //
 
 enum llama_example {
-    LLAMA_EXAMPLE_DEBUG,
     LLAMA_EXAMPLE_COMMON,
     LLAMA_EXAMPLE_SPECULATIVE,
     LLAMA_EXAMPLE_COMPLETION,
@@ -217,8 +216,6 @@ struct common_params_sampling {
     std::vector<llama_logit_bias> logit_bias; // logit biases to apply
     std::vector<llama_logit_bias> logit_bias_eog; // pre-calculated logit biases for EOG tokens
 
-    bool backend_sampling = false;
-
     bool has_logit_bias() const {
         return !logit_bias.empty();
     }
@@ -373,11 +370,6 @@ struct common_params {
     std::string lookup_cache_dynamic = ""; // path of dynamic ngram cache file for lookup decoding // NOLINT
     std::string logits_file = ""; // file for saving *all* logits // NOLINT
 
-    // llama-debug specific options
-    std::string logits_output_dir = "data"; // directory for saving logits output files // NOLINT
-    bool save_logits = false; // whether to save logits to files // NOLINT
-    std::vector<std::string> tensor_filter; // filter tensor names for debug output (regex) // NOLINT
-
     std::vector<std::string> in_files; // all input files
     std::vector<std::string> antiprompt; // strings upon which more user input is prompted (a.k.a. reverse prompts)
     std::vector<llama_model_kv_override> kv_overrides;
@@ -697,9 +689,7 @@ struct common_init_result {
     llama_model * model();
     llama_context * context();
     common_sampler * sampler(llama_seq_id seq_id);
-
-    void reset_samplers();
 
     std::vector<llama_adapter_lora_ptr> & lora();


@@ -106,16 +106,12 @@ static void llama_sampler_llg_free(llama_sampler * smpl) {
 }
 
 static llama_sampler_i llama_sampler_llg_i = {
     /* .name = */ llama_sampler_llg_name,
     /* .accept = */ llama_sampler_llg_accept_impl,
     /* .apply = */ llama_sampler_llg_apply,
     /* .reset = */ llama_sampler_llg_reset,
     /* .clone = */ llama_sampler_llg_clone,
     /* .free = */ llama_sampler_llg_free,
-    /* .backend_init = */ NULL,
-    /* .backend_accept = */ NULL,
-    /* .backend_apply = */ NULL,
-    /* .backend_set_input = */ NULL,
 };
 
 static size_t llama_sampler_llg_tokenize_fn(const void * user_data, const uint8_t * bytes, size_t bytes_len,


@@ -27,7 +27,7 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
         return res;
     }
 
     std::match_results<std::string::const_reverse_iterator> srmatch;
-    if (std::regex_search(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial, std::regex_constants::match_continuous)) {
+    if (std::regex_match(input.rbegin(), input.rend() - pos, srmatch, rx_reversed_partial)) {
         auto group = srmatch[1].str();
         if (group.length() != 0) {
             auto it = srmatch[1].second.base();
@@ -55,18 +55,18 @@ common_regex_match common_regex::search(const std::string & input, size_t pos, b
     to see if a string ends with a partial regex match, but but it's not in std::regex yet.
     Instead, we'll the regex into a partial match regex operating as a full match on the reverse iterators of the input.
 
-    - /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:(?:d)?c)?b)?a)
-    - /a|b/ -> ^(a|b)
+    - /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:(?:d)?c)?b)?a).*
+    - /a|b/ -> (a|b).*
     - /a*?/ -> error, could match ""
-    - /a*b/ -> ^((?:b)?a*+) (final repetitions become eager)
-    - /.*?ab/ -> ^((?:b)?a) (omit .*)
-    - /a.*?b/ -> ^((?:b)?.*?a) (keep reluctant matches)
-    - /a(bc)d/ -> ^((?:(?:d)?(?:(?:c)?b))?a)
-    - /a(bc|de)/ -> ^((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a)
-    - /ab{2,4}c/ -> ^cbbb?b?a -> ^((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a)
+    - /a*b/ -> ((?:b)?a*+).* (final repetitions become eager)
+    - /.*?ab/ -> ((?:b)?a).* (merge .*)
+    - /a.*?b/ -> ((?:b)?.*?a).* (keep reluctant matches)
+    - /a(bc)d/ -> ((?:(?:d)?(?:(?:c)?b))?a).*
+    - /a(bc|de)/ -> ((?:(?:(?:e)?d)?|(?:(?:c)?b)?)?a).*
+    - /ab{2,4}c/ -> abbb?b?c -> ((?:(?:(?:(?:(?:c)?b)?b)?b?)?b?)?a).*
 
-    The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern.
-    All other groups are turned into non-capturing groups, and reluctant quantifiers are ignored.
+    The regex will match a reversed string fully, and the end of the first (And only) capturing group will indicate the reversed start of the original partial pattern
+    (i.e. just where the final .* starts in the inverted pattern; all other groups are turned into non-capturing groups, and reluctant quantifiers are ignored)
 */
 std::string regex_to_reversed_partial_regex(const std::string & pattern) {
     auto it = pattern.begin();
@@ -177,7 +177,7 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
         }
     }
 
-    // /abcd/ -> ^(dcba|cba|ba|a) -> ^((?:(?:(?:d)?c)?b)?a)
+    // /abcd/ -> (dcba|cba|ba|a).* -> ((?:(?:(?:d)?c)?b)?a).*
     // if n(=4) parts, opening n-1(=3) non-capturing groups after the 1 capturing group
     // We'll do the outermost capturing group and final .* in the enclosing function.
     std::vector<std::string> res_alts;
@@ -200,5 +200,5 @@ std::string regex_to_reversed_partial_regex(const std::string & pattern) {
         throw std::runtime_error("Unmatched '(' in pattern");
     }
 
-    return "^(" + res + ")";
+    return "(" + res + ")[\\s\\S]*";
 }
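
Note: the technique described on the + side of the comment above can be demonstrated in isolation. To test whether an input ends with a prefix of /abcd/, match the reversed input in full against the hand-derived reversed partial pattern ((?:(?:(?:d)?c)?b)?a)[\s\S]*; the extent of capture group 1, anchored at the start of the reversed string, tells you where the partial match begins in the original. A minimal sketch:

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // hand-derived reversed partial pattern for /abcd/ (see comment above)
        const std::regex rx_rev_partial("((?:(?:(?:d)?c)?b)?a)[\\s\\S]*");

        for (const std::string input : {"xyzab", "xyzabcd", "xyz"}) {
            const std::string rev(input.rbegin(), input.rend());
            std::smatch m;
            if (std::regex_match(rev, m, rx_rev_partial)) {
                // group 1 starts at position 0 of the reversed string,
                // so its length is the length of the partial match
                std::cout << input << ": partial match starts at "
                          << input.size() - m[1].length() << "\n";
            } else {
                std::cout << input << ": no partial match\n";
            }
        }
        // xyzab: partial match starts at 3
        // xyzabcd: partial match starts at 3
        // xyz: no partial match
    }
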


@@ -120,34 +120,17 @@ struct common_sampler {
     }
 
     void set_logits(struct llama_context * ctx, int idx) {
-        const float * sampled_probs = llama_get_sampled_probs_ith (ctx, idx);
-        const float * sampled_logits = llama_get_sampled_logits_ith (ctx, idx);
-        const llama_token * sampled_ids = llama_get_sampled_candidates_ith(ctx, idx);
+        const auto * logits = llama_get_logits_ith(ctx, idx);
 
         const llama_model * model = llama_get_model(ctx);
         const llama_vocab * vocab = llama_model_get_vocab(model);
 
         const int n_vocab = llama_vocab_n_tokens(vocab);
 
-        if (sampled_probs) {
-            const uint32_t sampled_probs_count = llama_get_sampled_probs_count_ith(ctx, idx);
-            cur.resize(sampled_probs_count);
-            for (uint32_t i = 0; i < sampled_probs_count; ++i) {
-                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], sampled_probs[i]};
-            }
-        } else if (sampled_logits) {
-            const uint32_t sampled_logits_count = llama_get_sampled_logits_count_ith(ctx, idx);
-            cur.resize(sampled_logits_count);
-            for (uint32_t i = 0; i < sampled_logits_count; i++) {
-                cur[i] = llama_token_data{sampled_ids[i], sampled_logits[i], 0.0f};
-            }
-        } else {
-            const auto * logits = llama_get_logits_ith(ctx, idx);
-            GGML_ASSERT(logits != nullptr);
-            cur.resize(n_vocab);
-            for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
-                cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
-            }
+        cur.resize(n_vocab);
+
+        for (llama_token token_id = 0; token_id < n_vocab; token_id++) {
+            cur[token_id] = llama_token_data{token_id, logits[token_id], 0.0f};
         }
 
         cur_p = { cur.data(), cur.size(), -1, false };
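
Note: condensed for reference, the removed branch structure is a three-way fallback: backend-computed probabilities, then backend pre-filtered logits, then the full raw logits. A dependency-free sketch (types and parameters simplified; the real code obtains these arrays via the llama_get_sampled_*_ith getters shown above):

    #include <cstdint>
    #include <vector>

    struct token_data { int32_t id; float logit; float p; }; // stand-in for llama_token_data

    static std::vector<token_data> build_candidates(
            const float *   sampled_probs,  // non-null: backend already ran softmax
            const float *   sampled_logits, // non-null: backend pre-filtered candidates
            const int32_t * sampled_ids,    // candidate ids for the two cases above
            uint32_t        sampled_count,
            const float *   logits,         // full-vocab logits (CPU fallback)
            int32_t         n_vocab) {
        std::vector<token_data> cur;
        if (sampled_probs) {
            cur.resize(sampled_count);
            for (uint32_t i = 0; i < sampled_count; ++i) {
                cur[i] = {sampled_ids[i], sampled_logits[i], sampled_probs[i]};
            }
        } else if (sampled_logits) {
            cur.resize(sampled_count);
            for (uint32_t i = 0; i < sampled_count; ++i) {
                cur[i] = {sampled_ids[i], sampled_logits[i], 0.0f};
            }
        } else {
            cur.resize(n_vocab);
            for (int32_t id = 0; id < n_vocab; ++id) {
                cur[id] = {id, logits[id], 0.0f};
            }
        }
        return cur; // the b7588 side keeps only the final branch
    }
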
@@ -176,7 +159,7 @@ std::string common_params_sampling::print() const {
     return std::string(result);
 }
 
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params) {
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params) {
     const llama_vocab * vocab = llama_model_get_vocab(model);
 
     llama_sampler_chain_params lparams = llama_sampler_chain_default_params();
@@ -196,30 +179,24 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
 #endif // LLAMA_USE_LLGUIDANCE
     } else {
         std::vector<std::string> trigger_patterns;
+        std::vector<std::string> patterns_anywhere;
         std::vector<llama_token> trigger_tokens;
         for (const auto & trigger : params.grammar_triggers) {
             switch (trigger.type) {
                 case COMMON_GRAMMAR_TRIGGER_TYPE_WORD:
                 {
                     const auto & word = trigger.value;
-                    trigger_patterns.push_back(regex_escape(word));
+                    patterns_anywhere.push_back(regex_escape(word));
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN:
                 {
-                    trigger_patterns.push_back(trigger.value);
+                    patterns_anywhere.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL:
                 {
-                    const auto & pattern = trigger.value;
-                    std::string anchored = "^$";
-                    if (!pattern.empty()) {
-                        anchored = (pattern.front() != '^' ? "^" : "")
-                            + pattern
-                            + (pattern.back() != '$' ? "$" : "");
-                    }
-                    trigger_patterns.push_back(anchored);
+                    trigger_patterns.push_back(trigger.value);
                     break;
                 }
                 case COMMON_GRAMMAR_TRIGGER_TYPE_TOKEN:
@@ -233,6 +210,10 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
             }
         }
 
+        if (!patterns_anywhere.empty()) {
+            trigger_patterns.push_back("^[\\s\\S]*?(" + string_join(patterns_anywhere, "|") + ")[\\s\\S]*");
+        }
+
         std::vector<const char *> trigger_patterns_c;
         trigger_patterns_c.reserve(trigger_patterns.size());
         for (const auto & regex : trigger_patterns) {
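
Note: on the + side, all word and partial-pattern triggers are folded into one full-match regex, ^[\s\S]*?( ... )[\s\S]*, whose first capture records where a trigger fired (and thus what gets forwarded to the grammar). A small sketch of how such a combined pattern behaves, with trigger strings chosen for illustration:

    #include <iostream>
    #include <regex>
    #include <string>

    int main() {
        // two "anywhere" triggers joined the same way as the string_join call above
        const std::regex rx("^[\\s\\S]*?(<tool_call>|<function)[\\s\\S]*");

        const std::string output = "let me think...\n<tool_call>{\"name\": \"get_time\"}";

        std::smatch m;
        if (std::regex_match(output, m, rx)) {
            // lazy prefix => group 1 lands on the first trigger occurrence
            std::cout << "trigger at offset " << m.position(1) << "\n"; // 16
        }
    }
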
@@ -315,12 +296,6 @@ struct common_sampler * common_sampler_init(const struct llama_model * model, st
         llama_sampler_chain_add(chain, smpl);
     }
 
-    if (grmr && params.backend_sampling) {
-        LOG_WRN("%s: backend sampling is not compatible with grammar, disabling\n", __func__);
-        params.backend_sampling = false;
-    }
-
     auto * result = new common_sampler {
         /* .params = */ params,
         /* .grmr = */ grmr,
@@ -430,25 +405,6 @@ llama_token common_sampler_sample(struct common_sampler * gsmpl, struct llama_co
     auto & chain = gsmpl->chain;
     auto & cur_p = gsmpl->cur_p; // initialized by set_logits
 
-    // Check if a backend sampler has already sampled a token in which case we
-    // return that token id directly.
-    {
-        id = llama_get_sampled_token_ith(ctx, idx);
-        if (id != LLAMA_TOKEN_NULL) {
-            LOG_DBG("%s: Backend sampler selected token: '%d'. Will not run any CPU samplers\n", __func__, id);
-            GGML_ASSERT(!gsmpl->grmr && "using grammar in combination with backend sampling is not supported");
-
-            // TODO: simplify
-            gsmpl->cur.resize(1);
-            gsmpl->cur[0] = { id, 0.0f, 1.0f };
-
-            cur_p = { gsmpl->cur.data(), gsmpl->cur.size(), 0, true };
-
-            return id;
-        }
-    }
-
     gsmpl->set_logits(ctx, idx);
 
     if (grammar_first) {


@@ -36,8 +36,7 @@ struct common_sampler;
 
 // llama_sampler API overloads
 
-// note: can mutate params in some cases
-struct common_sampler * common_sampler_init(const struct llama_model * model, struct common_params_sampling & params);
+struct common_sampler * common_sampler_init(const struct llama_model * model, const struct common_params_sampling & params);
 
 void common_sampler_free(struct common_sampler * gsmpl);
@@ -49,7 +48,6 @@ struct common_sampler * common_sampler_clone (struct common_sampler * gsmpl);
 // arguments can be nullptr to skip printing
 void common_perf_print(const struct llama_context * ctx, const struct common_sampler * gsmpl);
 
-// get the underlying llama_sampler_chain
 struct llama_sampler * common_sampler_get(const struct common_sampler * gsmpl);
 
 // extended sampling implementation:


@ -771,14 +771,9 @@ class TextModel(ModelBase):
self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {} self.rope_parameters = self.hparams.get("rope_parameters", self.hparams.get("rope_scaling")) or {}
rope_theta = self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)
local_rope_theta = self.find_hparam(["local_rope_theta", "rope_local_theta", "swa_rope_theta", "rope_local_base_freq"], optional=True)
# Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters # Ensure "rope_theta" and "rope_type" is mirrored in rope_parameters
if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters: if "full_attention" not in self.rope_parameters and "sliding_attention" not in self.rope_parameters:
if local_rope_theta is not None: if "rope_theta" not in self.rope_parameters and (rope_theta := self.find_hparam(["rope_theta", "global_rope_theta", "rotary_emb_base"], optional=True)) is not None:
self.rope_parameters["sliding_attention"] = {"rope_theta": local_rope_theta}
if "rope_theta" not in self.rope_parameters and rope_theta is not None:
self.rope_parameters["rope_theta"] = rope_theta self.rope_parameters["rope_theta"] = rope_theta
if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None: if "rope_type" not in self.rope_parameters and (rope_type := self.rope_parameters.get("type")) is not None:
self.rope_parameters["rope_type"] = rope_type self.rope_parameters["rope_type"] = rope_type
@ -844,7 +839,6 @@ class TextModel(ModelBase):
self.gguf_writer.add_head_count_kv(n_head_kv) self.gguf_writer.add_head_count_kv(n_head_kv)
logger.info(f"gguf: key-value head count = {n_head_kv}") logger.info(f"gguf: key-value head count = {n_head_kv}")
# TODO: Handle "sliding_attention" similarly when models start implementing it
rope_params = self.rope_parameters.get("full_attention", self.rope_parameters) rope_params = self.rope_parameters.get("full_attention", self.rope_parameters)
if (rope_type := rope_params.get("rope_type")) is not None: if (rope_type := rope_params.get("rope_type")) is not None:
rope_factor = rope_params.get("factor") rope_factor = rope_params.get("factor")
@ -891,9 +885,6 @@ class TextModel(ModelBase):
if (rope_theta := rope_params.get("rope_theta")) is not None: if (rope_theta := rope_params.get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base(rope_theta) self.gguf_writer.add_rope_freq_base(rope_theta)
logger.info(f"gguf: rope theta = {rope_theta}") logger.info(f"gguf: rope theta = {rope_theta}")
if (local_rope_theta := self.rope_parameters.get("sliding_attention", {}).get("rope_theta")) is not None:
self.gguf_writer.add_rope_freq_base_swa(local_rope_theta)
logger.info(f"gguf: rope theta swa = {local_rope_theta}")
if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None: if (f_rms_eps := self.find_hparam(["rms_norm_eps", "norm_eps"], optional=True)) is not None:
self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps) self.gguf_writer.add_layer_norm_rms_eps(f_rms_eps)
logger.info(f"gguf: rms norm epsilon = {f_rms_eps}") logger.info(f"gguf: rms norm epsilon = {f_rms_eps}")
@ -1071,9 +1062,6 @@ class TextModel(ModelBase):
if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273": if chkhsh == "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273":
# ref: https://huggingface.co/alvarobartt/grok-2-tokenizer # ref: https://huggingface.co/alvarobartt/grok-2-tokenizer
res = "grok-2" res = "grok-2"
if chkhsh == "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df":
# ref: https://huggingface.co/aari1995/German_Semantic_V3
res = "jina-v2-de"
if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5": if chkhsh == "0ef9807a4087ebef797fc749390439009c3b9eda9ad1a097abbe738f486c01e5":
# ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B # ref: https://huggingface.co/meta-llama/Meta-Llama-3-8B
res = "llama-bpe" res = "llama-bpe"
@ -1242,12 +1230,6 @@ class TextModel(ModelBase):
if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665": if chkhsh == "4a2e2abae11ca2b86d570fc5b44be4d5eb5e72cc8f22dd136a94b37da83ab665":
# ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer # ref: https://huggingface.co/KORMo-Team/KORMo-tokenizer
res = "kormo" res = "kormo"
if chkhsh == "9d70134b369a70e5735009b6de918f7581b5211f7c074d1f89f753aea8248af1":
# ref: https://huggingface.co/tencent/Youtu-LLM-2B
res = "youtu"
if chkhsh == "16389f0a1f51ee53e562ffd51c371dc508639ab0e4261502071836e50e223e91":
# ref: https://huggingface.co/upstage/Solar-Open-100B
res = "solar-open"
if res is None: if res is None:
logger.warning("\n") logger.warning("\n")
@ -2504,7 +2486,6 @@ class StableLMModel(TextModel):
"VLlama3ForCausalLM", "VLlama3ForCausalLM",
"LlavaForConditionalGeneration", "LlavaForConditionalGeneration",
"VoxtralForConditionalGeneration", "VoxtralForConditionalGeneration",
"IQuestCoderForCausalLM",
"LlamaModel") "LlamaModel")
class LlamaModel(TextModel): class LlamaModel(TextModel):
model_arch = gguf.MODEL_ARCH.LLAMA model_arch = gguf.MODEL_ARCH.LLAMA
@ -3522,7 +3503,7 @@ class QwenModel(TextModel):
self._set_vocab_qwen() self._set_vocab_qwen()
@ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM", "AudioFlamingo3ForConditionalGeneration") @ModelBase.register("Qwen2Model", "Qwen2ForCausalLM", "Qwen2AudioForConditionalGeneration", "KORMoForCausalLM")
class Qwen2Model(TextModel): class Qwen2Model(TextModel):
model_arch = gguf.MODEL_ARCH.QWEN2 model_arch = gguf.MODEL_ARCH.QWEN2
@ -5013,6 +4994,7 @@ class Plamo3Model(TextModel):
if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None: if (sliding_window := self.find_hparam(["window_size", "sliding_window"], optional=True)) is not None:
self.gguf_writer.add_sliding_window(sliding_window) self.gguf_writer.add_sliding_window(sliding_window)
self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"]) self.gguf_writer.add_sliding_window_pattern(self.hparams["sliding_window_pattern"])
self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("rope_local_theta")})["rope_theta"])
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]: def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@ -5302,14 +5284,13 @@ class BertModel(TextModel):
self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1)) self.gguf_writer.add_token_type_count(self.hparams.get("type_vocab_size", 1))
# convert to phantom space vocab # convert to phantom space vocab
def phantom(tok, toktype): def phantom(tok):
if toktype == gguf.TokenType.CONTROL: if tok.startswith("[") and tok.endswith("]"):
return tok return tok
if tok.startswith("##"): if tok.startswith("##"):
return tok[2:] return tok[2:]
return "\u2581" + tok return "\u2581" + tok
assert len(tokens) == len(toktypes) tokens = list(map(phantom, tokens))
tokens = list(map(phantom, tokens, toktypes))
# add vocab to gguf # add vocab to gguf
self.gguf_writer.add_tokenizer_model("bert") self.gguf_writer.add_tokenizer_model("bert")
@ -6423,17 +6404,6 @@ class ARwkv7Model(Rwkv7Model):
self.gguf_writer.add_head_count(0) self.gguf_writer.add_head_count(0)
@ModelBase.register("MaincoderForCausalLM")
class MaincoderModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAINCODER
def set_gguf_parameters(self):
super().set_gguf_parameters()
if (head_dim := self.hparams.get("head_dim")) is not None:
self.gguf_writer.add_rope_dimension_count(head_dim)
@ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM") @ModelBase.register("MambaForCausalLM", "MambaLMHeadModel", "FalconMambaForCausalLM")
class MambaModel(TextModel): class MambaModel(TextModel):
model_arch = gguf.MODEL_ARCH.MAMBA model_arch = gguf.MODEL_ARCH.MAMBA
@@ -7211,8 +7181,6 @@ class DeepseekModel(TextModel):
    "DeepseekV2ForCausalLM",
    "DeepseekV3ForCausalLM",
    "KimiVLForConditionalGeneration",
-    "YoutuForCausalLM",
-    "YoutuVLForConditionalGeneration"
)
class DeepseekV2Model(TextModel):
    model_arch = gguf.MODEL_ARCH.DEEPSEEK2
@@ -7279,15 +7247,7 @@ class DeepseekV2Model(TextModel):
        super().set_gguf_parameters()
        hparams = self.hparams

-        # first_k_dense_replace: number of leading layers using dense FFN instead of MoE
-        # For non-MoE models (like Youtu), set to n_layer to use dense FFN for all layers
-        # For MoE models (like DeepSeek-V2), this is the number of leading non-MoE layers
-        has_moe = hparams.get("n_routed_experts") is not None
-        first_k_dense_replace = hparams.get("first_k_dense_replace")
-        if first_k_dense_replace is None:
-            # Default: if no MoE, all layers are dense; if MoE, none are dense
-            first_k_dense_replace = hparams["num_hidden_layers"] if not has_moe else 0
-        self.gguf_writer.add_leading_dense_block_count(first_k_dense_replace)
+        self.gguf_writer.add_leading_dense_block_count(hparams["first_k_dense_replace"])
        self.gguf_writer.add_vocab_size(hparams["vocab_size"])
        if "q_lora_rank" in hparams and hparams["q_lora_rank"] is not None:
            self.gguf_writer.add_q_lora_rank(hparams["q_lora_rank"])
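A quick illustration of the defaulting rule the removed comments describe, with hypothetical hparams dicts (not real model configs):

```python
def default_first_k_dense(hparams: dict) -> int:
    # mirrors the removed logic: absent hparam -> all-dense for non-MoE, zero for MoE
    has_moe = hparams.get("n_routed_experts") is not None
    fk = hparams.get("first_k_dense_replace")
    return fk if fk is not None else (hparams["num_hidden_layers"] if not has_moe else 0)

print(default_first_k_dense({"num_hidden_layers": 24}))                          # 24: dense-only model
print(default_first_k_dense({"num_hidden_layers": 24, "n_routed_experts": 64}))  # 0: MoE, no dense prefix
print(default_first_k_dense({"num_hidden_layers": 24, "n_routed_experts": 64,
                             "first_k_dense_replace": 1}))                       # 1: explicit config wins
```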
@@ -7299,24 +7259,11 @@ class DeepseekV2Model(TextModel):
            self.gguf_writer.add_key_length_mla(hparams["qk_nope_head_dim"] + hparams["qk_rope_head_dim"])
            self.gguf_writer.add_value_length_mla(hparams["v_head_dim"])

-        # MoE parameters (required by C++ code for DEEPSEEK2 arch)
-        # For non-MoE models like Youtu, use intermediate_size as expert_feed_forward_length
-        moe_intermediate_size = self.find_hparam(["moe_intermediate_size", "intermediate_size"], optional=False)
-        self.gguf_writer.add_expert_feed_forward_length(moe_intermediate_size)
-
-        if (n_routed_experts := hparams.get("n_routed_experts")) is not None:
-            self.gguf_writer.add_expert_count(n_routed_experts)
-
-        # expert_shared_count is required by C++ code, default to 0 for non-MoE models
-        n_shared_experts = hparams.get("n_shared_experts", 0)
-        self.gguf_writer.add_expert_shared_count(n_shared_experts)
-
-        # When not set, C++ code will use scale_w = false to skip the no-op scaling
-        if (routed_scaling_factor := hparams.get("routed_scaling_factor")) is not None:
-            self.gguf_writer.add_expert_weights_scale(routed_scaling_factor)
-
-        if (norm_topk_prob := hparams.get("norm_topk_prob")) is not None and norm_topk_prob:
-            self.gguf_writer.add_expert_weights_norm(norm_topk_prob)
+        self.gguf_writer.add_expert_feed_forward_length(hparams["moe_intermediate_size"])
+        self.gguf_writer.add_expert_count(hparams["n_routed_experts"])
+        self.gguf_writer.add_expert_shared_count(hparams["n_shared_experts"])
+        self.gguf_writer.add_expert_weights_scale(hparams["routed_scaling_factor"])
+        self.gguf_writer.add_expert_weights_norm(hparams["norm_topk_prob"])
        self.gguf_writer.add_rope_dimension_count(hparams["qk_rope_head_dim"])
@@ -7332,17 +7279,10 @@ class DeepseekV2Model(TextModel):
        # skip vision tensors and remove "language_model." for Kimi-VL
        if "vision_tower" in name or "multi_modal_projector" in name:
            return []
-        if name.startswith("siglip2.") or name.startswith("merger."):
-            return []
        if name.startswith("language_model."):
            name = name.replace("language_model.", "")

-        # skip lm_head.weight if tie_word_embeddings is True
-        if self.hparams.get("tie_word_embeddings", False):
-            if name == "lm_head.weight" or name == "model.lm_head.weight":
-                logger.info("Skipping tied output layer 'lm_head.weight' (will use token_embd.weight)")
-                return []
-
        # rename e_score_correction_bias tensors
        if name.endswith("e_score_correction_bias"):
            name = name.replace("e_score_correction_bias", "e_score_correction.bias")
@@ -7489,6 +7429,7 @@ class MimoV2Model(TextModel):
        self.gguf_writer.add_sliding_window(self.hparams["sliding_window"])
        self.gguf_writer.add_sliding_window_pattern(self.hparams["hybrid_layer_pattern"])
+        self.gguf_writer.add_rope_freq_base_swa(self.hparams["swa_rope_theta"])
        self.gguf_writer.add_value_length(self.hparams["v_head_dim"])
        self.gguf_writer.add_expert_count(self.hparams["n_routed_experts"])
        self.gguf_writer.add_expert_feed_forward_length(self.hparams["moe_intermediate_size"])
@@ -9351,19 +9292,6 @@ class VoxtralWhisperEncoderModel(WhisperEncoderModel):
        self.gguf_writer.add_audio_stack_factor(4) # == intermediate_size // hidden_size

-@ModelBase.register("AudioFlamingo3ForConditionalGeneration")
-class AudioFlamingo3WhisperEncoderModel(WhisperEncoderModel):
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.MUSIC_FLAMINGO)
-
-    def tensor_force_quant(self, name, new_name, bid, n_dims):
-        if ".conv" in name and ".weight" in name:
-            # the model was trained in BF16; to be safe, avoid quantizing these weights to FP16
-            return gguf.GGMLQuantizationType.F32
-        return super().tensor_force_quant(name, new_name, bid, n_dims)
-
@ModelBase.register("FalconH1ForCausalLM")
class FalconH1Model(Mamba2Model):
    model_arch = gguf.MODEL_ARCH.FALCON_H1
@@ -9956,27 +9884,6 @@ class LFM2Model(TextModel):
        return any(p in name for p in ["audio", "codebook", "conformer", "depth_embedding", "depthformer", "depth_linear"])

-@ModelBase.register("Lfm2Model")
-class LFM2ColBertModel(LFM2Model):
-    model_arch = gguf.MODEL_ARCH.LFM2
-    dense_tensor_name = "dense_2"
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        if not name.startswith(self.dense_tensor_name):
-            name = "model." + name
-        return super().modify_tensors(data_torch, name, bid)
-
-    def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
-        # dense tensor is stored in a separate safetensors file
-        from safetensors.torch import load_file
-        tensors_file = self.dir_model / "1_Dense" / "model.safetensors"
-        assert tensors_file.is_file()
-        tensor = load_file(tensors_file)["linear.weight"]
-        self.gguf_writer.add_embedding_length_out(tensor.shape[0])
-        yield f"{self.dense_tensor_name}.weight", tensor.clone()
-
@ModelBase.register("Lfm2MoeForCausalLM")
class LFM2MoeModel(TextModel):
    model_arch = gguf.MODEL_ARCH.LFM2MOE
@@ -10247,6 +10154,7 @@ class ModernBertModel(BertModel):
        self.gguf_writer.add_sliding_window(self.hparams["local_attention"])
        if (sliding_window_pattern := self.hparams.get("global_attn_every_n_layers")) is not None:
            self.gguf_writer.add_sliding_window_pattern(sliding_window_pattern)
+        self.gguf_writer.add_rope_freq_base_swa(self.rope_parameters.get("sliding_attention", {"rope_theta": self.hparams.get("local_rope_theta")})["rope_theta"])
        self.gguf_writer.add_rope_scaling_type(gguf.RopeScalingType.NONE)
        self.gguf_writer.add_vocab_size(self.hparams["vocab_size"])
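The one-liner added here (and the analogous Plamo3 line above) packs a two-level fallback: prefer `rope_parameters["sliding_attention"]["rope_theta"]` when the config has a sliding-attention section, otherwise fall back to the local rope theta. Spelled out with hypothetical values:

```python
# Hypothetical config without a "sliding_attention" section:
rope_parameters: dict = {}
hparams = {"local_rope_theta": 10000.0}

theta = rope_parameters.get(
    "sliding_attention",
    {"rope_theta": hparams.get("local_rope_theta")},  # default used when the key is absent
)["rope_theta"]
print(theta)  # 10000.0
```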
@@ -10696,79 +10604,6 @@ class JanusProVisionModel(MmprojModel):
        return []

-@ModelBase.register("YoutuVLForConditionalGeneration")
-class YoutuVLVisionModel(MmprojModel):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        assert self.hparams_vision is not None
-        self.hparams_vision["image_size"] = self.hparams_vision.get("image_size", 560)
-
-    def set_gguf_parameters(self):
-        super().set_gguf_parameters()
-        self.gguf_writer.add_clip_projector_type(gguf.VisionProjectorType.YOUTUVL)
-        self.gguf_writer.add_vision_attention_layernorm_eps(self.hparams.get("layer_norm_eps", 1e-6))
-
-        # handle activation function
-        hidden_act = str(self.hparams.get("hidden_act", "gelu_pytorch_tanh")).lower()
-        if hidden_act in ("gelu", "gelu_pytorch_tanh", "gelu_fast", "gelu_new", "gelu_accurate"):
-            self.gguf_writer.add_vision_use_gelu(True)
-        elif hidden_act == "silu":
-            self.gguf_writer.add_vision_use_silu(True)
-        else:
-            raise ValueError(f"Unsupported activation function for YOUTUVL: {hidden_act}")
-
-        self.gguf_writer.add_vision_spatial_merge_size(self.hparams.get("spatial_merge_size", 2))
-        window_size = self.hparams.get("window_size")
-        if window_size is not None:
-            self.gguf_writer.add_vision_window_size(window_size)
-
-        # fullatt_block_indexes contains the explicit layer indices that use full attention,
-        # e.g. [2, 5, 8, 11] means layers 2, 5, 8 and 11 use full attention;
-        # all other layers use window attention
-        fullatt_block_indexes = self.hparams.get("fullatt_block_indexes")
-        assert fullatt_block_indexes is not None, "fullatt_block_indexes is required for youtuvl"
-        # store the explicit layer indices for YoutuVL (irregular pattern approach)
-        self.gguf_writer.add_vision_wa_layer_indexes(layers=fullatt_block_indexes)
-
-    def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
-        del bid  # unused
-        # skip language model tensors
-        skip_prefixes = ('lm_head.', 'model.layers.', 'model.embed_tokens.', 'model.norm.')
-        if name.startswith(skip_prefixes):
-            return []
-        # try to map the tensor using TensorNameMap (handles vision encoder and projector)
-        try:
-            new_name = self.map_tensor_name(name)
-            return [(new_name, data_torch)]
-        except ValueError:
-            # if mapping fails, log a warning and skip the tensor
-            logger.warning(f"Cannot map tensor: {name}")
-            return []
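To make the irregular-pattern convention in the removed class concrete, here is a toy expansion of `fullatt_block_indexes` into per-layer attention kinds (hypothetical values; real ones come from the model's config.json):

```python
fullatt_block_indexes = [2, 5, 8, 11]  # hypothetical
n_layers = 12

kinds = ["full" if i in fullatt_block_indexes else "window" for i in range(n_layers)]
print(kinds)
# ['window', 'window', 'full', 'window', 'window', 'full',
#  'window', 'window', 'full', 'window', 'window', 'full']
```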
@ModelBase.register("SolarOpenForCausalLM")
class SolarOpenModel(Glm4MoeModel):
model_arch = gguf.MODEL_ARCH.GLM4_MOE
def set_vocab(self):
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained(self.dir_model)
special_vocab = gguf.SpecialVocab(self.dir_model, load_merges=True)
tokens, toktypes, tokpre = self.get_vocab_base()
self.gguf_writer.add_tokenizer_model("gpt2")
self.gguf_writer.add_tokenizer_pre(tokpre)
self.gguf_writer.add_token_list(tokens)
self.gguf_writer.add_token_types(toktypes)
special_vocab._set_special_token("eos", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("eot", tokenizer.get_added_vocab()["<|endoftext|>"])
special_vocab._set_special_token("unk", tokenizer.get_added_vocab()["<unk>"])
special_vocab._set_special_token("bos", tokenizer.get_added_vocab()["<|startoftext|>"])
special_vocab.add_to_gguf(self.gguf_writer)
###### CONVERSION LOGIC ###### ###### CONVERSION LOGIC ######
@@ -10974,8 +10809,8 @@ def parse_args() -> argparse.Namespace:
    parser.add_argument(
        "--sentence-transformers-dense-modules", action="store_true",
-        help=("Whether to include sentence-transformers dense modules. "
-              "It can be used for sentence-transformers models, like google/embeddinggemma-300m. "
+        help=("Whether to include sentence-transformers dense modules."
+              "It can be used for sentence-transformers models, like google/embeddinggemma-300m"
              "Default these modules are not included.")
    )
View File
@@ -145,8 +145,6 @@ models = [
    {"name": "granite-docling", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/ibm-granite/granite-docling-258M", },
    {"name": "minimax-m2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/MiniMaxAI/MiniMax-M2", },
    {"name": "kormo", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/KORMo-Team/KORMo-tokenizer", },
-    {"name": "youtu", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/tencent/Youtu-LLM-2B", },
-    {"name": "solar-open", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/upstage/Solar-Open-100B", },
]

# some models are known to be broken upstream, so we will skip them as exceptions
@@ -167,8 +165,6 @@ pre_computed_hashes = [
    {"name": "kimi-k2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/moonshotai/Kimi-K2-Base", "chkhsh": "81212dc7cdb7e0c1074ca62c5aeab0d43c9f52b8a737be7b12a777c953027890"},
    {"name": "qwen2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/Qwen/Qwen3-Embedding-0.6B", "chkhsh": "d4540891389ea895b53b399da6ac824becc30f2fba0e9ddbb98f92e55ca0e97c"},
    {"name": "grok-2", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/alvarobartt/grok-2-tokenizer", "chkhsh": "66b8d4e19ab16c3bfd89bce5d785fb7e0155e8648708a1f42077cb9fe002c273"},
-    # jina-v2-de variants
-    {"name": "jina-v2-de", "tokt": TOKENIZER_TYPE.BPE, "repo": "https://huggingface.co/aari1995/German_Semantic_V3", "chkhsh": "b3d1dd861f1d4c5c0d2569ce36baf3f90fe8a102db3de50dd71ff860d91be3df"},
]
View File
@@ -327,7 +327,3 @@ Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. Whe
### GGML_CANN_PREFILL_USE_GRAPH

Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.

-### GGML_CANN_OPERATOR_FUSION
-
-Enable operator fusion during computation, default is false. This option fuses compatible operators (e.g., ADD + RMS_NORM) to reduce overhead and improve performance.
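These options are plain process environment variables read by the CANN backend at startup. A hedged launcher sketch, assuming the binary path and model file are placeholders:

```python
import os
import subprocess

# GGML_CANN_PREFILL_USE_GRAPH is read from the environment (per the docs above,
# it only takes effect when flash attention is enabled).
env = dict(os.environ, GGML_CANN_PREFILL_USE_GRAPH="1")
subprocess.run(["./build/bin/llama-cli", "-m", "model.gguf"], env=env, check=True)
```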
View File
@@ -218,56 +218,6 @@ cmake .. -G Ninja `
ninja
```

-## Linux
-
-The two build steps above also apply to Linux. The commands are mostly the same as the PowerShell ones for Windows, except that the second step omits the `-DCMAKE_TOOLCHAIN_FILE` parameter and, in both steps, the line-continuation backticks become backslashes.
-
-If they are not installed already, install Git, CMake, Clang, Ninja and Python, then run the following in a terminal:
-
-### I. Setup Environment
-
-1. **Install OpenCL Headers and Library**
-
-```bash
-mkdir -p ~/dev/llm
-cd ~/dev/llm
-
-git clone https://github.com/KhronosGroup/OpenCL-Headers && cd OpenCL-Headers
-mkdir build && cd build
-cmake .. -G Ninja \
-  -DBUILD_TESTING=OFF \
-  -DOPENCL_HEADERS_BUILD_TESTING=OFF \
-  -DOPENCL_HEADERS_BUILD_CXX_TESTS=OFF \
-  -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
-cmake --build . --target install
-
-cd ~/dev/llm
-git clone https://github.com/KhronosGroup/OpenCL-ICD-Loader && cd OpenCL-ICD-Loader
-mkdir build && cd build
-cmake .. -G Ninja \
-  -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \
-  -DCMAKE_INSTALL_PREFIX="$HOME/dev/llm/opencl"
-cmake --build . --target install
-```
-
-### II. Build llama.cpp
-
-```bash
-mkdir -p ~/dev/llm
-cd ~/dev/llm
-
-git clone https://github.com/ggml-org/llama.cpp && cd llama.cpp
-mkdir build && cd build
-cmake .. -G Ninja \
-  -DCMAKE_BUILD_TYPE=Release \
-  -DCMAKE_PREFIX_PATH="$HOME/dev/llm/opencl" \
-  -DBUILD_SHARED_LIBS=OFF \
-  -DGGML_OPENCL=ON
-ninja
-```
-
## Known Issues

- Flash attention does not always improve performance.
View File
@@ -22,7 +22,7 @@ Legend:
| ARANGE | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGMAX | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| ARGSORT | ❌ | ✅ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ❌ | ❌ | ❌ |
| CEIL | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | | ❌ | ❌ |
| CLAMP | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ | ❌ |
| CONCAT | ❌ | ✅ | ✅ | 🟡 | ✅ | 🟡 | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONT | ❌ | 🟡 | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | 🟡 | ❌ | ❌ |
@@ -32,7 +32,7 @@ Legend:
| CONV_TRANSPOSE_1D | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CONV_TRANSPOSE_2D | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| COS | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| COUNT_EQUAL | ❌ | ✅ | ✅ | ✅ | | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ |
| CPY | ❌ | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ |
| CROSS_ENTROPY_LOSS | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
| CROSS_ENTROPY_LOSS_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ❌ |
View File
@@ -965,7 +965,6 @@
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,1,2560],ne_kernel=[3,3,1,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[12,12,2,2560],ne_kernel=[3,3,2,2560],s0=1,s1=1,p0=1,p1=1,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[5,5,1,32],ne_kernel=[3,4,1,32],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
-"Metal","IM2COL","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[2,2,1536,729],ne_kernel=[2,2,1536,4096],s0=1,s1=1,p0=0,p1=0,d0=1,d1=1,is_2D=1","support","1","yes","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f32,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f32,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
"Metal","IM2COL_3D","type_input=f32,type_kernel=f16,dst_type=f16,ne_input=[10,10,10,9],ne_kernel=[3,3,3,1],IC=3,s0=1,s1=1,s2=1,p0=1,p1=1,p2=1,d0=1,d1=1,d2=1,v=0","support","0","no","Metal"
@@ -4965,9 +4964,8 @@
"Metal","CONV_TRANSPOSE_1D","ne_input=[2,1,1,1],ne_kernel=[3,1,1,1],s0=1,p0=0,d0=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[3,2,3,1],ne_kernel=[2,2,1,3],stride=1","support","1","yes","Metal"
"Metal","CONV_TRANSPOSE_2D","ne_input=[10,10,9,1],ne_kernel=[3,3,1,9],stride=2","support","1","yes","Metal"
-"Metal","CONV_TRANSPOSE_2D","ne_input=[129,63,35,1],ne_kernel=[3,3,48,35],stride=1","support","1","yes","Metal"
-"Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","1","yes","Metal"
-"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","1","yes","Metal"
+"Metal","COUNT_EQUAL","type=f32,ne=[4,500,1,1]","support","0","no","Metal"
+"Metal","COUNT_EQUAL","type=f32,ne=[4,5000,1,1]","support","0","no","Metal"
"Metal","ARGMAX","type=f32,ne=[32,1,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[32,513,1,1]","support","1","yes","Metal"
"Metal","ARGMAX","type=f32,ne=[100,10,1,1]","support","1","yes","Metal"
@@ -5717,15 +5715,15 @@
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
"Metal","RMS_NORM","type=f32,ne=[64,5,4,3],v=0,eps=0.000001,inplace=1","support","1","yes","Metal"
"Metal","L2_NORM","type=f32,ne=[64,5,4,3]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[6,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[6,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[6,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[3,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[3,1024,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[8,1536,1,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,1536,4,1],ne_b=[3,1536,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
+"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[3,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,1024,1,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,1024,4,1],ne_b=[4,1024,1,1]","support","1","yes","Metal"
@@ -5735,15 +5733,6 @@
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[8,2048,1,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_CONV","type=f32,ne_a=[4,2048,4,1],ne_b=[4,2048,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[18,1024,1,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,1024,4,1],ne_b=[9,1024,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[18,1536,1,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,1536,4,1],ne_b=[9,1536,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[18,2048,1,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
-"Metal","SSM_CONV","type=f32,ne_a=[9,2048,4,1],ne_b=[9,2048,1,1]","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=16,head_dim=1,n_head=1024,n_group=1,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=128,head_dim=64,n_head=16,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
"Metal","SSM_SCAN","type=f32,d_state=256,head_dim=64,n_head=8,n_group=2,n_seq_tokens=32,n_seqs=4","support","1","yes","Metal"
@@ -8927,8 +8916,6 @@
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=0,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=0.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX","type=f32,ne=[32,2,32,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
-"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f32,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
-"Metal","SOFT_MAX","type=f32,ne=[200001,2,3,1],mask=1,sinks=1,m_prec=f16,nr23=[1,1],scale=0.100000,max_bias=8.000000,inplace=0","support","1","yes","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[15,15,1,1],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
"Metal","SOFT_MAX_BACK","type=f32,ne=[16,16,2,3],scale=1.000000,max_bias=0.000000","support","0","no","Metal"
@@ -9555,311 +9542,311 @@
"Metal","ARGSORT","type=f32,ne=[2048,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2049,2,1,3],order=1","support","1","yes","Metal"
"Metal","ARGSORT","type=f32,ne=[2,8,8192,1],order=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[12,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[13,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[15,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[19,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[27,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[43,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[64,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[75,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[128,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[139,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[256,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[267,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[512,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[523,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1024,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[1035,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2048,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[2059,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4096,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[4107,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8192,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[8203,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=9999","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[16395,1,2,1],k=9999","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=7","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=15","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=100","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=500","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=1023","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32768,1,1,1],k=9999","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[32779,1,2,1],k=9999","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=2","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=3","support","1","yes","Metal"
-"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3,ties=0","support","1","yes","Metal"
+"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65536,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[65547,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131072,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[131083,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262144,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[262155,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=100","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=500","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=1023","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524288,1,1,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[524299,1,2,1],k=9999","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=1","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=2","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=3","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=7","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16,10,10,10],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[60,10,10,10],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1023,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1024,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[1025,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[16384,1,1,1],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2047,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2048,2,1,3],k=15","support","1","yes","Metal"
"Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15,ties=0","support","1","yes","Metal" "Metal","TOP_K","type=f32,ne=[2049,2,1,3],k=15","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=0","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[512,512,3,2],scale_factor=2,mode=nearest,transpose=1","support","1","yes","Metal"
"Metal","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Metal" "Metal","UPSCALE","type=f32,ne=[2,5,7,11],ne_tgt=[5,7,11,13],mode=nearest,flags=none","support","1","yes","Metal"
@ -9904,9 +9891,8 @@
"Metal","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","Metal" "Metal","GROUP_NORM","type=f32,ne=[64,64,320,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","Metal" "Metal","GROUP_NORM","type=f32,ne=[9,9,1280,1],num_groups=32,eps=0.000001","support","1","yes","Metal"
"Metal","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","Metal" "Metal","ACC","type=f32,ne_a=[256,17,1,1],ne_b=[256,16,1,1]","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1,circular=0","support","1","yes","Metal" "Metal","PAD","type=f32,ne_a=[512,512,1,1],pad_0=1,pad_1=1","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[33,17,2,1],pad_0=4,pad_1=3,circular=1","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,3,1],lp0=1,rp0=1,lp1=1,rp1=1,lp2=1,rp2=1,lp3=1,rp3=1,v=0,circular=0","support","0","no","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","Metal" "Metal","PAD_REFLECT_1D","type=f32,ne_a=[512,34,2,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","Metal" "Metal","PAD_REFLECT_1D","type=f32,ne_a=[3000,384,4,1],pad_0=10,pad_1=9","support","1","yes","Metal"
"Metal","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","Metal" "Metal","ROLL","shift0=3,shift1=-2,shift3=1,shift4=-1","support","0","no","Metal"
@ -9937,41 +9923,17 @@
"Metal","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[303,207,11,3],c=2.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[800,600,4,4],c=-152.000000","support","1","yes","Metal"
"Metal","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Metal" "Metal","FILL","type=f32,ne=[2048,512,2,2],c=3.500000","support","1","yes","Metal"
"Metal","DIAG","type=f32,ne=[10,1,4,3]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[79,1,19,13]","support","0","no","Metal"
"Metal","DIAG","type=f32,ne=[256,1,8,16]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[10,10,4,3],ne_rhs=[3,10,4,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[11,11,1,1],ne_rhs=[5,11,1,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[17,17,2,4],ne_rhs=[9,17,2,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[30,30,7,1],ne_rhs=[8,30,7,1]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[42,42,5,2],ne_rhs=[10,42,5,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[10,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,2,2],ne_rhs=[64,64,2,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[79,79,5,3],ne_rhs=[417,79,5,3]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,2],ne_rhs=[32,128,4,2]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[80,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[79,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,2,8],ne_rhs=[81,80,2,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[80,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[79,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[80,80,8,8],ne_rhs=[81,80,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[84,84,4,4],ne_rhs=[32,84,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[95,95,8,8],ne_rhs=[40,95,8,8]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Metal" "Metal","SOLVE_TRI","type=f32,ne_lhs=[100,100,4,4],ne_rhs=[41,100,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[31,128,4,4]","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0","support","1","yes","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,4],ne_rhs=[32,128,4,4]","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,3,4],ne_rhs=[32,128,3,4]","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1","support","1","yes","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[128,128,4,1],ne_rhs=[32,128,4,1]","support","0","no","Metal" "Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[200,64,4,4]","support","0","no","Metal"
"Metal","SOLVE_TRI","type=f32,ne_lhs=[64,64,4,4],ne_rhs=[384,64,4,4]","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=0,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=0","support","1","yes","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=0","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[512,512,1,1],lp0=0,rp0=1,lp1=0,rp1=1,lp2=0,rp2=0,lp3=0,rp3=0,v=1,circular=1","support","0","no","Metal"
"Metal","PAD","type=f32,ne_a=[11,22,33,44],lp0=1,rp0=2,lp1=3,rp1=4,lp2=5,rp2=6,lp3=7,rp3=8,v=1,circular=1","support","0","no","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f32,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=f16,permute=[0,1,2,3]","support","1","yes","Metal"
"Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","Metal" "Metal","FLASH_ATTN_EXT","hsk=40,hsv=40,nh=4,nr23=[1,1],kv=113,nb=1,mask=1,sinks=1,max_bias=0.000000,logit_softcap=0.000000,prec=f32,type_KV=bf16,permute=[0,1,2,3]","support","1","yes","Metal"

Can't render this file because it is too large.

File diff suppressed because it is too large

View File

@ -15,7 +15,6 @@ llama_add_compile_flags()
if (EMSCRIPTEN) if (EMSCRIPTEN)
else() else()
add_subdirectory(batched) add_subdirectory(batched)
add_subdirectory(debug)
add_subdirectory(embedding) add_subdirectory(embedding)
add_subdirectory(eval-callback) add_subdirectory(eval-callback)
@ -35,6 +34,7 @@ else()
add_subdirectory(gen-docs) add_subdirectory(gen-docs)
add_subdirectory(training) add_subdirectory(training)
add_subdirectory(diffusion) add_subdirectory(diffusion)
add_subdirectory(model-conversion)
if (NOT GGML_BACKEND_DL) if (NOT GGML_BACKEND_DL)
add_subdirectory(convert-llama2c-to-ggml) add_subdirectory(convert-llama2c-to-ggml)
# these examples use the backends directly and cannot be built with dynamic loading # these examples use the backends directly and cannot be built with dynamic loading

View File

@ -68,7 +68,7 @@ int main(int argc, char ** argv) {
auto sparams = llama_sampler_chain_default_params(); auto sparams = llama_sampler_chain_default_params();
sparams.no_perf = false; sparams.no_perf = false;
std::vector<llama_sampler_seq_config> sampler_configs; std::vector<llama_sampler *> samplers;
for (int32_t i = 0; i < n_parallel; ++i) { for (int32_t i = 0; i < n_parallel; ++i) {
llama_sampler * smpl = llama_sampler_chain_init(sparams); llama_sampler * smpl = llama_sampler_chain_init(sparams);
@ -78,13 +78,7 @@ int main(int argc, char ** argv) {
llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp)); llama_sampler_chain_add(smpl, llama_sampler_init_temp (params.sampling.temp));
llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed)); llama_sampler_chain_add(smpl, llama_sampler_init_dist (params.sampling.seed));
sampler_configs.push_back({ i, smpl }); samplers.push_back(smpl);
}
// TODO: temporarily gated behind a flag
if (params.sampling.backend_sampling) {
ctx_params.samplers = sampler_configs.data();
ctx_params.n_samplers = sampler_configs.size();
} }
llama_context * ctx = llama_init_from_model(model, ctx_params); llama_context * ctx = llama_init_from_model(model, ctx_params);
@ -186,7 +180,7 @@ int main(int argc, char ** argv) {
continue; continue;
} }
const llama_token new_token_id = llama_sampler_sample(sampler_configs[i].sampler, ctx, i_batch[i]); const llama_token new_token_id = llama_sampler_sample(samplers[i], ctx, i_batch[i]);
// is it an end of generation? -> mark the stream as finished // is it an end of generation? -> mark the stream as finished
if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) { if (llama_vocab_is_eog(vocab, new_token_id) || n_cur == n_predict) {
@ -242,15 +236,15 @@ int main(int argc, char ** argv) {
__func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f)); __func__, n_decode, (t_main_end - t_main_start) / 1000000.0f, n_decode / ((t_main_end - t_main_start) / 1000000.0f));
LOG("\n"); LOG("\n");
llama_perf_sampler_print(sampler_configs[0].sampler); llama_perf_sampler_print(samplers[0]);
llama_perf_context_print(ctx); llama_perf_context_print(ctx);
fprintf(stderr, "\n"); fprintf(stderr, "\n");
llama_batch_free(batch); llama_batch_free(batch);
for (auto & sampler_config : sampler_configs) { for (auto & sampler_config : samplers) {
llama_sampler_free(sampler_config.sampler); llama_sampler_free(sampler_config);
} }
llama_free(ctx); llama_free(ctx);

View File

@ -1,54 +0,0 @@
# llama.cpp/examples/debug
This is a utility intended to help debug a model by registering a callback that
logs GGML operations and tensor data. It can also store the generated logits or
embeddings as well as the prompt and token ids for comparison with the original
model.
### Usage
```shell
llama-debug \
--hf-repo ggml-org/models \
--hf-file phi-2/ggml-model-q4_0.gguf \
--model phi-2-q4_0.gguf \
--prompt hello \
--save-logits \
--verbose
```
The tensor data is logged at debug level and requires the `--verbose` flag. The
reason for this is that, while useful, a model with many layers can produce a
lot of output. You can filter the tensor names using the `--tensor-filter` option.
A recommended approach is to first run without `--verbose` and check whether the
generated logits/embeddings are close to the original model's. If they are not,
it may be necessary to inspect the model tensor by tensor; in that case, enable
the `--verbose` flag along with `--tensor-filter` to focus on specific tensors.
### Options
This example supports all standard `llama.cpp` options and also accepts the
following options:
```console
$ llama-debug --help
...
----- example-specific params -----
--save-logits save final logits to files for verification (default: false)
--logits-output-dir PATH directory for saving logits output files (default: data)
--tensor-filter REGEX filter tensor names for debug output (regex pattern, can be specified multiple times)
```
### Output Files
When `--save-logits` is enabled, the following files are created in the output
directory:
* `llamacpp-<model>[-embeddings].bin` - Binary output (logits or embeddings)
* `llamacpp-<model>[-embeddings].txt` - Text output (logits or embeddings, one per line)
* `llamacpp-<model>[-embeddings]-prompt.txt` - Prompt text and token IDs
* `llamacpp-<model>[-embeddings]-tokens.bin` - Binary token IDs for programmatic comparison
These files can be compared against the original model's output to verify the
converted model.
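For reference, these files are raw arrays in native byte order with no header:
`float` values for the logits/embeddings and `llama_token` (`int32_t`) values
for the tokens. Below is a minimal C++ sketch of loading the saved logits back
for comparison; the path assumes the default `data` output directory and the
`phi-2-q4_0` model from the usage example above.
```cpp
// Minimal sketch: read the raw float logits written by llama-debug.
// Assumption: the file is a headerless array of native-endian floats.
#include <cstdio>
#include <fstream>
#include <vector>

int main() {
    std::ifstream file("data/llamacpp-phi-2-q4_0.bin", std::ios::binary | std::ios::ate);
    if (!file) {
        std::fprintf(stderr, "failed to open logits file\n");
        return 1;
    }
    const std::streamsize n_bytes = file.tellg();
    file.seekg(0);

    std::vector<float> logits(n_bytes / sizeof(float));
    file.read(reinterpret_cast<char *>(logits.data()), n_bytes);

    // compare element-wise against the original model's logits here
    std::printf("loaded %zu logits\n", logits.size());
    return 0;
}
```
The `-tokens.bin` file can be read the same way with `int32_t` elements.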

View File

@ -1,421 +0,0 @@
#include "arg.h"
#include "common.h"
#include "log.h"
#include "llama.h"
#include "ggml.h"
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <string>
#include <vector>
#include <filesystem>
#include <fstream>
#include <regex>
static void print_usage(int, char ** argv) {
const std::string usage_template = R"(
example usage:
Print tensors:
{prog} -m model.gguf -p "Hello my name is" --verbose
The tensors to be printed can be filtered with --tensor-filter option.
Save logits/embeddings:
{prog} -m model.gguf -p "Hello my name is" --save-logits
Add --embedding to save embeddings)" "\n";
// Fix the source code indentation above that is introduced by the raw string literal.
std::string usage = std::regex_replace(usage_template, std::regex("\\n {8}"), "\n");
usage = std::regex_replace(usage, std::regex("\\{prog\\}"), argv[0]);
LOG("%s\n", usage.c_str());
}
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data);
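// Holds the compiled tensor-name filters plus a staging buffer used to copy tensor data back from non-host backend buffers.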
struct callback_data {
std::vector<uint8_t> data;
std::vector<std::regex> tensor_filters;
callback_data() = default;
callback_data(common_params & params, const std::vector<std::string> & filter_patterns) {
for (const auto & pattern : filter_patterns) {
try {
std::string anchored_pattern = "^" + pattern;
tensor_filters.emplace_back(anchored_pattern, std::regex::optimize);
} catch (const std::regex_error & e) {
throw std::runtime_error("Invalid regex pattern '" + pattern + "': " + e.what());
}
}
params.cb_eval = ggml_debug;
params.cb_eval_user_data = this;
}
};
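// Gathers the final logits, or the (optionally pooled and normalized) embeddings, together with the tokenized prompt, so they can be written out for verification.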
struct output_data {
float * data_ptr = nullptr;
int data_size = 0;
std::string type_suffix;
std::vector<float> storage;
std::string prompt;
std::vector<llama_token> tokens;
output_data(llama_context * ctx, const llama_model * model, const common_params & params) {
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
tokens = common_tokenize(ctx, params.prompt, add_bos);
prompt = params.prompt;
if (params.embedding) {
const int n_embd = llama_model_n_embd_out(model);
const bool pooling_enabled = llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE;
const int n_embd_count = pooling_enabled ? 1 : tokens.size();
const int n_embeddings = n_embd * n_embd_count;
float * embeddings;
if (pooling_enabled) {
embeddings = llama_get_embeddings_seq(ctx, 0);
storage.resize(n_embeddings);
common_embd_normalize(embeddings, storage.data(), n_embeddings, params.embd_normalize);
embeddings = storage.data();
} else {
embeddings = llama_get_embeddings(ctx);
}
data_ptr = embeddings;
data_size = n_embeddings;
type_suffix = "-embeddings";
} else {
const float * logits = llama_get_logits_ith(ctx, tokens.size() - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
data_ptr = const_cast<float*>(logits);
data_size = n_logits;
type_suffix = "";
}
}
};
static std::string ggml_ne_string(const ggml_tensor * t) {
std::string str;
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
str += std::to_string(t->ne[i]);
if (i + 1 < GGML_MAX_DIMS) {
str += ", ";
}
}
return str;
}
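// bf16 keeps the upper 16 bits of an IEEE-754 float32, so widening to f32 is a plain 16-bit left shift.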
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
static float ggml_get_float_value(const uint8_t * data, ggml_type type,
const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
switch (type) {
case GGML_TYPE_F16:
return ggml_fp16_to_fp32(*(const ggml_fp16_t *) &data[i]);
case GGML_TYPE_F32:
return *(const float *) &data[i];
case GGML_TYPE_I64:
return (float) *(const int64_t *) &data[i];
case GGML_TYPE_I32:
return (float) *(const int32_t *) &data[i];
case GGML_TYPE_I16:
return (float) *(const int16_t *) &data[i];
case GGML_TYPE_I8:
return (float) *(const int8_t *) &data[i];
case GGML_TYPE_BF16:
return ggml_compute_bf16_to_fp32(*(const ggml_bf16_t *) &data[i]);
default:
GGML_ABORT("fatal error");
}
}
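// Prints at most n leading and n trailing elements per dimension, eliding the middle with "...", and logs sum / sum_sq over the whole tensor as a quick fingerprint (aborting on NaN).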
static void ggml_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
GGML_ASSERT(n > 0);
float sum = 0;
float sum_sq = 0.0;
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
sum += v;
sum_sq += v * v;
}
}
}
}
for (int64_t i3 = 0; i3 < ne[3]; i3++) {
LOG_DBG(" [\n");
for (int64_t i2 = 0; i2 < ne[2]; i2++) {
if (i2 == n && ne[2] > 2*n) {
LOG_DBG(" ..., \n");
i2 = ne[2] - n;
}
LOG_DBG(" [\n");
for (int64_t i1 = 0; i1 < ne[1]; i1++) {
if (i1 == n && ne[1] > 2*n) {
LOG_DBG(" ..., \n");
i1 = ne[1] - n;
}
LOG_DBG(" [");
for (int64_t i0 = 0; i0 < ne[0]; i0++) {
if (i0 == n && ne[0] > 2*n) {
LOG_DBG("..., ");
i0 = ne[0] - n;
}
const float v = ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
LOG_DBG("%12.4f", v);
if (i0 < ne[0] - 1) {
LOG_DBG(", ");
}
}
LOG_DBG("],\n");
}
LOG_DBG(" ],\n");
}
LOG_DBG(" ]\n");
LOG_DBG(" sum = %f\n", sum);
LOG_DBG(" sum_sq = %f\n", sum_sq);
}
if (std::isnan(sum)) {
LOG_ERR("encountered NaN - aborting\n");
exit(0);
}
}
/**
* GGML operations callback during the graph execution.
*
* @param t current tensor
* @param ask when ask is true, the scheduler wants to know if we are interested in data from this tensor
* if we return true, a follow-up call will be made with ask=false in which we can do the actual collection.
* see ggml_backend_sched_eval_callback
* @param user_data user data to pass at each call back
* @return true to receive data or continue the graph, false otherwise
*/
static bool ggml_debug(struct ggml_tensor * t, bool ask, void * user_data) {
auto * cb_data = (callback_data *) user_data;
const struct ggml_tensor * src0 = t->src[0];
const struct ggml_tensor * src1 = t->src[1];
if (ask) {
return true; // Always retrieve data
}
bool matches_filter = cb_data->tensor_filters.empty();
if (!matches_filter) {
for (const auto & filter : cb_data->tensor_filters) {
if (std::regex_search(t->name, filter)) {
matches_filter = true;
break;
}
}
}
char src1_str[128] = {0};
if (src1) {
snprintf(src1_str, sizeof(src1_str), "%s{%s}", src1->name, ggml_ne_string(src1).c_str());
}
if (matches_filter) {
LOG_DBG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__,
t->name,
ggml_type_name(t->type),
ggml_op_desc(t),
src0->name,
ggml_ne_string(src0).c_str(),
src1 ? src1_str : "",
ggml_ne_string(t).c_str());
}
const bool is_host = ggml_backend_buffer_is_host(t->buffer);
if (!is_host) {
auto n_bytes = ggml_nbytes(t);
cb_data->data.resize(n_bytes);
ggml_backend_tensor_get(t, cb_data->data.data(), 0, n_bytes);
}
if (!ggml_is_quantized(t->type) && matches_filter) {
uint8_t * data = is_host ? (uint8_t *) t->data : cb_data->data.data();
ggml_print_tensor(data, t->type, t->ne, t->nb, 3);
}
return true;
}
static void save_output_data(const output_data & output, const std::string & model_name, const std::string & output_dir) {
std::filesystem::create_directory(output_dir);
auto base_path = std::filesystem::path{output_dir} / ("llamacpp-" + model_name + output.type_suffix);
// Save logits/embeddings to binary file.
{
std::filesystem::path filepath{base_path.string() + ".bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open binary output file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.data_ptr), output.data_size * sizeof(float));
LOG("Data saved to %s\n", filepath.c_str());
}
// Save logits/embeddings to text file.
{
std::filesystem::path filepath{base_path.string() + ".txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open text output file: " + filepath.string());
}
for (int i = 0; i < output.data_size; i++) {
file << i << ": " << output.data_ptr[i] << '\n';
}
LOG("Data saved to %s\n", filepath.c_str());
}
// Save prompt and tokens to text file.
{
std::filesystem::path filepath{base_path.string() + "-prompt.txt"};
std::ofstream file{filepath};
if (!file) {
throw std::runtime_error("failed to open prompt output file: " + filepath.string());
}
file << "prompt: " << output.prompt << '\n';
file << "n_tokens: " << output.tokens.size() << '\n';
file << "token ids: ";
for (size_t i = 0; i < output.tokens.size(); i++) {
file << output.tokens[i];
if (i + 1 < output.tokens.size()) {
file << ", ";
}
}
file << '\n';
LOG("Prompt saved to %s\n", filepath.c_str());
}
// Save token ids to binary file.
{
std::filesystem::path filepath{base_path.string() + "-tokens.bin"};
std::ofstream file{filepath, std::ios::binary};
if (!file) {
throw std::runtime_error("failed to open tokens binary file: " + filepath.string());
}
file.write(reinterpret_cast<const char*>(output.tokens.data()), output.tokens.size() * sizeof(llama_token));
LOG("Tokens saved to %s\n", filepath.c_str());
}
}
static void print_tokenized_prompt(llama_context * ctx, const std::vector<llama_token> & tokens, const std::string & prompt) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
LOG("Model add_bos: %s\n", llama_vocab_get_add_bos(vocab) ? "true" : "false");
LOG("Input prompt: \"%s\"\n", prompt.c_str());
LOG("Token ids (%zu):\n", tokens.size());
for (auto id : tokens) {
std::string piece(128, '\0');
int n = llama_token_to_piece(vocab, id, piece.data(), piece.size(), 0, true);
if (n < 0) {
LOG_ERR("failed to convert token %d to piece\n", id);
continue;
}
piece.resize(n);
LOG("%s(%d) ", piece.c_str(), id);
}
LOG("\n");
}
static bool run(llama_context * ctx, const common_params & params) {
const llama_model * model = llama_get_model(ctx);
const llama_vocab * vocab = llama_model_get_vocab(model);
const bool add_bos = llama_vocab_get_add_bos(vocab);
std::vector<llama_token> tokens = common_tokenize(ctx, params.prompt, add_bos);
if (tokens.empty()) {
LOG_ERR("%s : there are not input tokens to process - (try to provide a prompt with '-p')\n", __func__);
return false;
}
if (llama_decode(ctx, llama_batch_get_one(tokens.data(), tokens.size()))) {
LOG_ERR("%s : failed to eval\n", __func__);
return false;
}
print_tokenized_prompt(ctx, tokens, params.prompt);
if (params.save_logits) {
output_data output {ctx, model, params};
std::filesystem::path model_path{params.model.path};
std::string model_name{model_path.stem().string()};
save_output_data(output, model_name, params.logits_output_dir);
}
return true;
}
int main(int argc, char ** argv) {
common_params params;
if (!common_params_parse(argc, argv, params, LLAMA_EXAMPLE_DEBUG, print_usage)) {
return 1;
}
common_init();
llama_backend_init();
llama_numa_init(params.numa);
callback_data cb_data(params, params.tensor_filter);
auto llama_init = common_init_from_params(params);
auto * model = llama_init->model();
auto * ctx = llama_init->context();
if (model == nullptr || ctx == nullptr) {
LOG_ERR("%s : failed to init\n", __func__);
return 1;
}
{
LOG_INF("\n");
LOG_INF("%s\n", common_params_get_system_info(params).c_str());
LOG_INF("\n");
}
if (!run(ctx, params)) {
return 1;
}
LOG("\n");
llama_perf_context_print(ctx);
llama_backend_free();
return 0;
}

View File

@ -33,7 +33,7 @@ static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & toke
} }
} }
static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd_out, int embd_norm) { static void batch_decode(llama_context * ctx, llama_batch & batch, float * output, int n_seq, int n_embd, int embd_norm) {
const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
// clear previous kv_cache values (irrelevant for embeddings) // clear previous kv_cache values (irrelevant for embeddings)
@ -65,8 +65,8 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu
GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");
} }
float * out = output + embd_pos * n_embd_out; float * out = output + embd_pos * n_embd;
common_embd_normalize(embd, out, n_embd_out, embd_norm); common_embd_normalize(embd, out, n_embd, embd_norm);
} }
} }
@ -252,8 +252,8 @@ int main(int argc, char ** argv) {
} }
// allocate output // allocate output
const int n_embd_out = llama_model_n_embd_out(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_embd_count * n_embd_out, 0); std::vector<float> embeddings(n_embd_count * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();
// break into batches // break into batches
@ -267,8 +267,8 @@ int main(int argc, char ** argv) {
// encode if at capacity // encode if at capacity
if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) { if (batch.n_tokens + n_toks > n_batch || s >= n_seq_max) {
float * out = emb + e * n_embd_out; float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s; e += pooling_type == LLAMA_POOLING_TYPE_NONE ? batch.n_tokens : s;
s = 0; s = 0;
common_batch_clear(batch); common_batch_clear(batch);
@ -280,8 +280,8 @@ int main(int argc, char ** argv) {
} }
// final batch // final batch
float * out = emb + e * n_embd_out; float * out = emb + e * n_embd;
batch_decode(ctx, batch, out, s, n_embd_out, params.embd_normalize); batch_decode(ctx, batch, out, s, n_embd, params.embd_normalize);
if (params.embd_out.empty()) { if (params.embd_out.empty()) {
LOG("\n"); LOG("\n");
@ -289,19 +289,19 @@ int main(int argc, char ** argv) {
if (pooling_type == LLAMA_POOLING_TYPE_NONE) { if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
for (int j = 0; j < n_embd_count; j++) { for (int j = 0; j < n_embd_count; j++) {
LOG("embedding %d: ", j); LOG("embedding %d: ", j);
for (int i = 0; i < std::min(3, n_embd_out); i++) { for (int i = 0; i < std::min(3, n_embd); i++) {
if (params.embd_normalize == 0) { if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]); LOG("%6.0f ", emb[j * n_embd + i]);
} else { } else {
LOG("%9.6f ", emb[j * n_embd_out + i]); LOG("%9.6f ", emb[j * n_embd + i]);
} }
} }
LOG(" ... "); LOG(" ... ");
for (int i = n_embd_out - 3; i < n_embd_out; i++) { for (int i = n_embd - 3; i < n_embd; i++) {
if (params.embd_normalize == 0) { if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]); LOG("%6.0f ", emb[j * n_embd + i]);
} else { } else {
LOG("%9.6f ", emb[j * n_embd_out + i]); LOG("%9.6f ", emb[j * n_embd + i]);
} }
} }
LOG("\n"); LOG("\n");
@ -320,9 +320,9 @@ int main(int argc, char ** argv) {
for (uint32_t i = 0; i < n_cls_out; i++) { for (uint32_t i = 0; i < n_cls_out; i++) {
// NOTE: if you change this log - update the tests in ci/run.sh // NOTE: if you change this log - update the tests in ci/run.sh
if (n_cls_out == 1) { if (n_cls_out == 1) {
LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd_out]); LOG("rerank score %d: %8.3f\n", j, emb[j * n_embd]);
} else { } else {
LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd_out + i], cls_out_labels[i].c_str()); LOG("rerank score %d: %8.3f [%s]\n", j, emb[j * n_embd + i], cls_out_labels[i].c_str());
} }
} }
} }
@ -330,11 +330,11 @@ int main(int argc, char ** argv) {
// print the first part of the embeddings or for a single prompt, the full embedding // print the first part of the embeddings or for a single prompt, the full embedding
for (int j = 0; j < n_prompts; j++) { for (int j = 0; j < n_prompts; j++) {
LOG("embedding %d: ", j); LOG("embedding %d: ", j);
for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd_out) : n_embd_out); i++) { for (int i = 0; i < (n_prompts > 1 ? std::min(16, n_embd) : n_embd); i++) {
if (params.embd_normalize == 0) { if (params.embd_normalize == 0) {
LOG("%6.0f ", emb[j * n_embd_out + i]); LOG("%6.0f ", emb[j * n_embd + i]);
} else { } else {
LOG("%9.6f ", emb[j * n_embd_out + i]); LOG("%9.6f ", emb[j * n_embd + i]);
} }
} }
LOG("\n"); LOG("\n");
@ -350,7 +350,7 @@ int main(int argc, char ** argv) {
LOG("\n"); LOG("\n");
for (int i = 0; i < n_prompts; i++) { for (int i = 0; i < n_prompts; i++) {
for (int j = 0; j < n_prompts; j++) { for (int j = 0; j < n_prompts; j++) {
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
LOG("%6.2f ", sim); LOG("%6.2f ", sim);
} }
LOG("%1.10s", prompts[i].c_str()); LOG("%1.10s", prompts[i].c_str());
@ -368,9 +368,9 @@ int main(int argc, char ** argv) {
if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j); if (notArray) LOG(" {\n \"object\": \"embedding\",\n \"index\": %d,\n \"embedding\": ",j);
LOG("["); LOG("[");
for (int i = 0;;) { // at least one iteration (n_embd > 0) for (int i = 0;;) { // at least one iteration (n_embd > 0)
LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd_out + i]); LOG(params.embd_normalize == 0 ? "%1.0f" : "%1.7f", emb[j * n_embd + i]);
i++; i++;
if (i < n_embd_out) LOG(","); else break; if (i < n_embd) LOG(","); else break;
} }
LOG(notArray ? "]\n }" : "]"); LOG(notArray ? "]\n }" : "]");
j++; j++;
@ -383,7 +383,7 @@ int main(int argc, char ** argv) {
for (int i = 0;;) { // at least two iterations (n_embd_count > 1) for (int i = 0;;) { // at least two iterations (n_embd_count > 1)
LOG(" ["); LOG(" [");
for (int j = 0;;) { // at least two iterations (n_embd_count > 1) for (int j = 0;;) { // at least two iterations (n_embd_count > 1)
float sim = common_embd_similarity_cos(emb + i * n_embd_out, emb + j * n_embd_out, n_embd_out); float sim = common_embd_similarity_cos(emb + i * n_embd, emb + j * n_embd, n_embd);
LOG("%6.2f", sim); LOG("%6.2f", sim);
j++; j++;
if (j < n_embd_count) LOG(", "); else break; if (j < n_embd_count) LOG(", "); else break;
@ -397,7 +397,7 @@ int main(int argc, char ** argv) {
if (notArray) LOG("\n}\n"); if (notArray) LOG("\n}\n");
} else if (params.embd_out == "raw") { } else if (params.embd_out == "raw") {
print_raw_embeddings(emb, n_embd_count, n_embd_out, model, pooling_type, params.embd_normalize); print_raw_embeddings(emb, n_embd_count, n_embd, model, pooling_type, params.embd_normalize);
} }
LOG("\n"); LOG("\n");

View File

@ -1,5 +1,5 @@
set(TARGET llama-debug) set(TARGET llama-logits)
add_executable(${TARGET} debug.cpp) add_executable(${TARGET} logits.cpp)
install(TARGETS ${TARGET} RUNTIME) install(TARGETS ${TARGET} RUNTIME)
target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT}) target_link_libraries(${TARGET} PRIVATE common llama ${CMAKE_THREAD_LIBS_INIT})
target_compile_features(${TARGET} PRIVATE cxx_std_17) target_compile_features(${TARGET} PRIVATE cxx_std_17)

View File

@ -0,0 +1,268 @@
#include "llama.h"
#include "common.h"
#include <cstdio>
#include <cstring>
#include <string>
#include <vector>
#include <ctype.h>
#include <filesystem>
static void print_usage(int, char ** argv) {
printf("\nexample usage:\n");
printf("\n %s -m model.gguf [-ngl n_gpu_layers] -embd-mode [-pooling] [-embd-norm <norm>] [prompt]\n", argv[0]);
printf("\n");
printf(" -embd-norm: normalization type for pooled embeddings (default: 2)\n");
printf(" -1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm\n");
printf("\n");
}
int main(int argc, char ** argv) {
std::string model_path;
std::string prompt = "Hello, my name is";
int ngl = 0;
bool embedding_mode = false;
bool pooling_enabled = false;
int32_t embd_norm = 2; // (-1=none, 0=max absolute int16, 1=taxicab, 2=Euclidean/L2, >2=p-norm)
{
int i = 1;
for (; i < argc; i++) {
if (strcmp(argv[i], "-m") == 0) {
if (i + 1 < argc) {
model_path = argv[++i];
} else {
print_usage(argc, argv);
return 1;
}
} else if (strcmp(argv[i], "-ngl") == 0) {
if (i + 1 < argc) {
try {
ngl = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
}
} else {
print_usage(argc, argv);
return 1;
}
} else if (strcmp(argv[i], "-embd-mode") == 0) {
embedding_mode = true;
} else if (strcmp(argv[i], "-pooling") == 0) {
pooling_enabled = true;
} else if (strcmp(argv[i], "-embd-norm") == 0) {
if (i + 1 < argc) {
try {
embd_norm = std::stoi(argv[++i]);
} catch (...) {
print_usage(argc, argv);
return 1;
}
} else {
print_usage(argc, argv);
return 1;
}
} else {
// prompt starts here
break;
}
}
if (model_path.empty()) {
print_usage(argc, argv);
return 1;
}
if (i < argc) {
prompt = argv[i++];
for (; i < argc; i++) {
prompt += " ";
prompt += argv[i];
}
}
}
ggml_backend_load_all();
llama_model_params model_params = llama_model_default_params();
model_params.n_gpu_layers = ngl;
llama_model * model = llama_model_load_from_file(model_path.c_str(), model_params);
if (model == NULL) {
fprintf(stderr , "%s: error: unable to load model\n" , __func__);
return 1;
}
// Extract basename from model_path
const char * basename = strrchr(model_path.c_str(), '/');
basename = (basename == NULL) ? model_path.c_str() : basename + 1;
char model_name[256];
strncpy(model_name, basename, 255);
model_name[255] = '\0';
char * dot = strrchr(model_name, '.');
if (dot != NULL && strcmp(dot, ".gguf") == 0) {
*dot = '\0';
}
printf("Model name: %s\n", model_name);
const llama_vocab * vocab = llama_model_get_vocab(model);
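// llama_tokenize, when called with a NULL output buffer, returns the negative
// of the required token count; negating it sizes the vector for the second,
// real tokenization pass below.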
const int n_prompt = -llama_tokenize(vocab, prompt.c_str(), prompt.size(), NULL, 0, true, true);
std::vector<llama_token> prompt_tokens(n_prompt);
if (llama_tokenize(vocab, prompt.c_str(), prompt.size(), prompt_tokens.data(), prompt_tokens.size(), true, true) < 0) {
fprintf(stderr, "%s: error: failed to tokenize the prompt\n", __func__);
return 1;
}
llama_context_params ctx_params = llama_context_default_params();
ctx_params.n_ctx = n_prompt;
ctx_params.n_batch = n_prompt;
ctx_params.no_perf = false;
if (embedding_mode) {
ctx_params.embeddings = true;
ctx_params.pooling_type = pooling_enabled ? LLAMA_POOLING_TYPE_MEAN : LLAMA_POOLING_TYPE_NONE;
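// non-causal/pooled embedding models must process the whole batch in a
// single micro-batch, so n_ubatch is raised to match n_batch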
ctx_params.n_ubatch = ctx_params.n_batch;
}
llama_context * ctx = llama_init_from_model(model, ctx_params);
if (ctx == NULL) {
fprintf(stderr, "%s: error: failed to create the llama_context\n", __func__);
return 1;
}
printf("Input prompt: \"%s\"\n", prompt.c_str());
printf("Tokenized prompt (%d tokens): ", n_prompt);
for (auto id : prompt_tokens) {
char buf[128];
int n = llama_token_to_piece(vocab, id, buf, sizeof(buf), 0, true);
if (n < 0) {
fprintf(stderr, "%s: error: failed to convert token to piece\n", __func__);
return 1;
}
std::string s(buf, n);
printf("%s (%d)", s.c_str(), id);
}
printf("\n");
llama_batch batch = llama_batch_get_one(prompt_tokens.data(), prompt_tokens.size());
if (llama_decode(ctx, batch)) {
fprintf(stderr, "%s : failed to eval\n", __func__);
return 1;
}
float * data_ptr;
int data_size;
const char * type;
std::vector<float> embd_out;
if (embedding_mode) {
const int n_embd = llama_model_n_embd(model);
const int n_embd_count = pooling_enabled ? 1 : batch.n_tokens;
const int n_embeddings = n_embd * n_embd_count;
float * embeddings;
type = "-embeddings";
if (llama_pooling_type(ctx) != LLAMA_POOLING_TYPE_NONE) {
embeddings = llama_get_embeddings_seq(ctx, 0);
embd_out.resize(n_embeddings);
printf("Normalizing embeddings using norm: %d\n", embd_norm);
common_embd_normalize(embeddings, embd_out.data(), n_embeddings, embd_norm);
embeddings = embd_out.data();
} else {
embeddings = llama_get_embeddings(ctx);
}
printf("Embedding dimension: %d\n", n_embd);
printf("\n");
// Print embeddings in the specified format
for (int j = 0; j < n_embd_count; j++) {
printf("embedding %d: ", j);
// Print first 3 values
for (int i = 0; i < 3 && i < n_embd; i++) {
printf("%9.6f ", embeddings[j * n_embd + i]);
}
printf(" ... ");
// Print last 3 values
for (int i = n_embd - 3; i < n_embd; i++) {
if (i >= 0) {
printf("%9.6f ", embeddings[j * n_embd + i]);
}
}
printf("\n");
}
printf("\n");
printf("Embeddings size: %d\n", n_embeddings);
data_ptr = embeddings;
data_size = n_embeddings;
} else {
float * logits = llama_get_logits_ith(ctx, batch.n_tokens - 1);
const int n_logits = llama_vocab_n_tokens(vocab);
type = "";
printf("Vocab size: %d\n", n_logits);
data_ptr = logits;
data_size = n_logits;
}
std::filesystem::create_directory("data");
// Save data to binary file
char bin_filename[512];
snprintf(bin_filename, sizeof(bin_filename), "data/llamacpp-%s%s.bin", model_name, type);
printf("Saving data to %s\n", bin_filename);
FILE * f = fopen(bin_filename, "wb");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open binary output file\n", __func__);
return 1;
}
fwrite(data_ptr, sizeof(float), data_size, f);
fclose(f);
// Also save as text for debugging
char txt_filename[512];
snprintf(txt_filename, sizeof(txt_filename), "data/llamacpp-%s%s.txt", model_name, type);
f = fopen(txt_filename, "w");
if (f == NULL) {
fprintf(stderr, "%s: error: failed to open text output file\n", __func__);
return 1;
}
for (int i = 0; i < data_size; i++) {
fprintf(f, "%d: %.6f\n", i, data_ptr[i]);
}
fclose(f);
if (!embedding_mode) {
printf("First 10 logits: ");
for (int i = 0; i < 10 && i < data_size; i++) {
printf("%.6f ", data_ptr[i]);
}
printf("\n");
printf("Last 10 logits: ");
for (int i = data_size - 10; i < data_size; i++) {
if (i >= 0) printf("%.6f ", data_ptr[i]);
}
printf("\n\n");
}
printf("Data saved to %s\n", bin_filename);
printf("Data saved to %s\n", txt_filename);
llama_free(ctx);
llama_model_free(model);
return 0;
}
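The -embd-norm values accepted by this tool map onto common_embd_normalize. A minimal sketch of the options as described in the usage text above (an illustration under those assumptions, not the library's exact implementation):

#include <math.h>

// Scale an n-element embedding so its chosen norm becomes 1.
// norm_type: -1 = none, 0 = max absolute into int16 range,
//            1 = taxicab (L1), 2 = Euclidean (L2), >2 = p-norm.
static void embd_normalize_sketch(const float * inp, float * out, int n, int norm_type) {
    double sum = 0.0;
    if (norm_type == -1) {
        sum = 1.0;                       // no normalization
    } else if (norm_type == 0) {
        for (int i = 0; i < n; i++) {
            sum = fmax(sum, fabs((double) inp[i]));
        }
        sum /= 32760.0;                  // map the largest value into int16 range (assumed constant)
    } else {
        for (int i = 0; i < n; i++) {
            sum += pow(fabs((double) inp[i]), norm_type);
        }
        sum = pow(sum, 1.0 / norm_type); // p-norm; p = 1 taxicab, p = 2 Euclidean
    }
    const double scale = sum > 0.0 ? 1.0 / sum : 0.0;
    for (int i = 0; i < n; i++) {
        out[i] = (float) (inp[i] * scale);
    }
}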
View File
@ -6,7 +6,7 @@ from pathlib import Path
# Add utils directory to path for direct script execution # Add utils directory to path for direct script execution
sys.path.insert(0, str(Path(__file__).parent.parent / "utils")) sys.path.insert(0, str(Path(__file__).parent.parent / "utils"))
from common import get_model_name_from_env_path, compare_tokens # type: ignore[import-not-found] from common import get_model_name_from_env_path # type: ignore[import-not-found]
def quick_logits_check(pytorch_file, llamacpp_file): def quick_logits_check(pytorch_file, llamacpp_file):
"""Lightweight sanity check before NMSE""" """Lightweight sanity check before NMSE"""
@ -58,13 +58,6 @@ def main():
print("Checked all required files were found. Proceeding...\n") print("Checked all required files were found. Proceeding...\n")
# Verify tokens as they are a prerequisite for logits comparison.
print("🔍 Token Comparison Check")
print("=" * 40)
if not compare_tokens(f"pytorch-{model_name}", f"llamacpp-{llamacpp_model_name}"):
print("\n❌ Token mismatch detected")
sys.exit(1)
print()
print("🔍 GGML Model Validation for model ", model_name) print("🔍 GGML Model Validation for model ", model_name)
print("=" * 40) print("=" * 40)
View File
@ -67,7 +67,7 @@ with torch.no_grad():
last_hidden_states = outputs.hidden_states[-1] last_hidden_states = outputs.hidden_states[-1]
# Get embeddings for all tokens # Get embeddings for all tokens
token_embeddings = last_hidden_states[0].float().cpu().numpy() # Remove batch dimension token_embeddings = last_hidden_states[0].cpu().numpy() # Remove batch dimension
print(f"Hidden states shape: {last_hidden_states.shape}") print(f"Hidden states shape: {last_hidden_states.shape}")
print(f"Token embeddings shape: {token_embeddings.shape}") print(f"Token embeddings shape: {token_embeddings.shape}")
View File
@ -13,6 +13,6 @@ if [ -z "$CONVERTED_MODEL" ]; then
exit 1 exit 1
fi fi
cmake --build ../../build --target llama-debug -j8 cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-debug -m $CONVERTED_MODEL --embedding -p "Hello world today" --save-logits ../../build/bin/llama-logits -m $CONVERTED_MODEL -embd-mode "Hello world today"
View File
@ -21,6 +21,6 @@ fi
echo $CONVERTED_MODEL echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-debug -j8 cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-debug -m "$CONVERTED_MODEL" -p "$MODEL_TESTING_PROMPT" --save-logits ../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"
View File
@ -7,11 +7,12 @@ import importlib
import torch import torch
import numpy as np import numpy as np
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig from transformers import AutoTokenizer, AutoModelForCausalLM, AutoModelForImageTextToText, AutoConfig
# Add parent directory to path for imports # Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..')) sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from utils.common import debug_hook, save_output_data from utils.common import debug_hook
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description="Process model with specified path") parser = argparse.ArgumentParser(description="Process model with specified path")
@ -125,7 +126,6 @@ def main():
device = next(model.parameters()).device device = next(model.parameters()).device
prompt = get_prompt(args) prompt = get_prompt(args)
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device) input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
token_ids = input_ids[0].cpu().tolist()
print(f"Input tokens: {input_ids}") print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}") print(f"Input text: {repr(prompt)}")
@ -151,6 +151,19 @@ def main():
print(f"Last token logits shape: {last_logits.shape}") print(f"Last token logits shape: {last_logits.shape}")
print(f"Vocab size: {len(last_logits)}") print(f"Vocab size: {len(last_logits)}")
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}.bin"
txt_filename = data_dir / f"pytorch-{model_name}.txt"
# Save to file for comparison
last_logits.astype(np.float32).tofile(bin_filename)
# Also save as text file for easy inspection
with open(txt_filename, "w") as f:
for i, logit in enumerate(last_logits):
f.write(f"{i}: {logit:.6f}\n")
# Print some sample logits for quick verification # Print some sample logits for quick verification
print(f"First 10 logits: {last_logits[:10]}") print(f"First 10 logits: {last_logits[:10]}")
print(f"Last 10 logits: {last_logits[-10:]}") print(f"Last 10 logits: {last_logits[-10:]}")
@ -162,7 +175,8 @@ def main():
token = tokenizer.decode([idx]) token = tokenizer.decode([idx])
print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}") print(f" Token {idx} ({repr(token)}): {last_logits[idx]:.6f}")
save_output_data(last_logits, token_ids, prompt, model_name) print(f"Saved bin logits to: {bin_filename}")
print(f"Saved txt logits to: {txt_filename}")
if __name__ == "__main__": if __name__ == "__main__":
main() main()
View File
@ -50,9 +50,10 @@ fi
echo $CONVERTED_MODEL echo $CONVERTED_MODEL
cmake --build ../../build --target llama-debug -j8 cmake --build ../../build --target llama-logits -j8
# TODO: update logits.cpp to accept a --file/-f option for the prompt
if [ -n "$USE_POOLING" ]; then if [ -n "$USE_POOLING" ]; then
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling mean -p "$PROMPT" --save-logits ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode -pooling "$PROMPT"
else else
../../build/bin/llama-debug -m "$CONVERTED_MODEL" --embedding --pooling none -p "$PROMPT" --save-logits ../../build/bin/llama-logits -m "$CONVERTED_MODEL" -embd-mode "$PROMPT"
fi fi
View File
@ -3,15 +3,13 @@
import argparse import argparse
import os import os
import sys import sys
import numpy as np
import importlib import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModel from transformers import AutoTokenizer, AutoConfig, AutoModel
import torch import torch
# Add parent directory to path for imports
sys.path.insert(0, os.path.join(os.path.dirname(__file__), '..'))
from utils.common import save_output_data
def parse_arguments(): def parse_arguments():
parser = argparse.ArgumentParser(description='Run original embedding model') parser = argparse.ArgumentParser(description='Run original embedding model')
@ -171,7 +169,6 @@ def main():
return_tensors="pt" return_tensors="pt"
) )
tokens = encoded['input_ids'][0] tokens = encoded['input_ids'][0]
token_ids = tokens.cpu().tolist()
token_strings = tokenizer.convert_ids_to_tokens(tokens) token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'") print(f"{token_id:6d} -> '{token_str}'")
@ -188,7 +185,6 @@ def main():
) )
tokens = encoded['input_ids'][0] tokens = encoded['input_ids'][0]
token_ids = tokens.cpu().tolist()
token_strings = tokenizer.convert_ids_to_tokens(tokens) token_strings = tokenizer.convert_ids_to_tokens(tokens)
for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)): for i, (token_id, token_str) in enumerate(zip(tokens, token_strings)):
print(f"{token_id:6d} -> '{token_str}'") print(f"{token_id:6d} -> '{token_str}'")
@ -232,11 +228,24 @@ def main():
print() print()
data_dir = Path("data")
data_dir.mkdir(exist_ok=True)
bin_filename = data_dir / f"pytorch-{model_name}-embeddings.bin"
txt_filename = data_dir / f"pytorch-{model_name}-embeddings.txt"
flattened_embeddings = all_embeddings.flatten() flattened_embeddings = all_embeddings.flatten()
flattened_embeddings.astype(np.float32).tofile(bin_filename)
with open(txt_filename, "w") as f:
idx = 0
for j in range(n_embd_count):
for value in all_embeddings[j]:
f.write(f"{idx}: {value:.6f}\n")
idx += 1
print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)") print(f"Total values: {len(flattened_embeddings)} ({n_embd_count} embeddings × {n_embd} dimensions)")
print("") print("")
print(f"Saved bin embeddings to: {bin_filename}")
save_output_data(flattened_embeddings, token_ids, prompt_text, model_name, type_suffix="-embeddings") print(f"Saved txt embeddings to: {txt_filename}")
if __name__ == "__main__": if __name__ == "__main__":
View File
@ -3,8 +3,6 @@
import os import os
import sys import sys
import torch import torch
import numpy as np
from pathlib import Path
def get_model_name_from_env_path(env_path_name): def get_model_name_from_env_path(env_path_name):
@ -150,96 +148,3 @@ def setup_rope_debug(model_module_path: str, function_name: str = "apply_rotary_
# Patch it # Patch it
setattr(module, function_name, debug_rope) setattr(module, function_name, debug_rope)
print(f"RoPE debug patching applied to {model_module_path}.{function_name}") print(f"RoPE debug patching applied to {model_module_path}.{function_name}")
def save_output_data(data, tokens, prompt, model_name, type_suffix="", output_dir="data"):
"""
Save output data (logits/embeddings), tokens, and prompt to files.
Args:
data: numpy array of floats (logits or embeddings)
tokens: list or array of token IDs
prompt: string containing the input prompt
model_name: name of the model
type_suffix: optional suffix like "-embeddings" (default: "")
output_dir: directory to save files (default: "data")
Creates the following files in output_dir:
- pytorch-{model_name}{type_suffix}.bin
- pytorch-{model_name}{type_suffix}.txt
- pytorch-{model_name}{type_suffix}-prompt.txt
- pytorch-{model_name}{type_suffix}-tokens.bin
"""
data_dir = Path(output_dir)
data_dir.mkdir(exist_ok=True)
base_path = data_dir / f"pytorch-{model_name}{type_suffix}"
# Convert and flatten logits/embeddings
data = data.cpu().numpy() if isinstance(data, torch.Tensor) else np.asarray(data)
data = data.flatten() if data.ndim > 1 else data
# Save logits/embedding files
data.astype(np.float32).tofile(f"{base_path}.bin")
print(f"Data saved to {base_path}.bin")
with open(f"{base_path}.txt", "w") as f:
f.writelines(f"{i}: {value:.6f}\n" for i, value in enumerate(data))
print(f"Data saved to {base_path}.txt")
# Convert and flatten tokens
tokens = tokens.cpu().numpy() if isinstance(tokens, torch.Tensor) else np.asarray(tokens)
tokens = tokens.flatten() if tokens.ndim > 1 else tokens
# Save token binary file
tokens.astype(np.int32).tofile(f"{base_path}-tokens.bin")
print(f"Tokens saved to {base_path}-tokens.bin")
# Save prompt file
with open(f"{base_path}-prompt.txt", "w") as f:
f.write(f"prompt: {prompt}\n")
f.write(f"n_tokens: {len(tokens)}\n")
f.write(f"token ids: {', '.join(str(int(tid)) for tid in tokens)}\n")
print(f"Prompt saved to {base_path}-prompt.txt")
def compare_tokens(original, converted, type_suffix="", output_dir="data"):
data_dir = Path(output_dir)
# Read tokens from both models
tokens1_file = data_dir / f"{original}{type_suffix}-tokens.bin"
tokens2_file = data_dir / f"{converted}{type_suffix}-tokens.bin"
if not tokens1_file.exists():
print(f"Error: Token file not found: {tokens1_file}")
return False
if not tokens2_file.exists():
print(f"Error: Token file not found: {tokens2_file}")
return False
tokens1 = np.fromfile(tokens1_file, dtype=np.int32)
tokens2 = np.fromfile(tokens2_file, dtype=np.int32)
print("\nComparing tokens between:")
print(f" Original : {original} ({len(tokens1)} tokens)")
print(f" Converted: {converted} ({len(tokens2)} tokens)")
if len(tokens1) != len(tokens2):
print(f"\n❌ Token count mismatch: {len(tokens1)} vs {len(tokens2)}")
return False
if np.array_equal(tokens1, tokens2):
print(f"\n✅ All {len(tokens1)} tokens match!")
return True
mismatches = np.where(tokens1 != tokens2)[0]
print(f"\n❌ Found {len(mismatches)} mismatched tokens:")
num_to_show = min(len(mismatches), 10)
for idx in mismatches[:num_to_show]:
print(f" Position {idx}: {tokens1[idx]} vs {tokens2[idx]}")
if len(mismatches) > num_to_show:
print(f" ... and {len(mismatches) - num_to_show} more mismatches")
return False
View File
@ -1,76 +0,0 @@
#!/usr/bin/env python3
import argparse
import sys
from common import compare_tokens # type: ignore
def parse_arguments():
parser = argparse.ArgumentParser(
description='Compare tokens between two models',
formatter_class=argparse.RawDescriptionHelpFormatter,
epilog="""
Examples:
%(prog)s pytorch-gemma-3-270m-it llamacpp-gemma-3-270m-it-bf16
"""
)
parser.add_argument(
'original',
help='Original model name'
)
parser.add_argument(
'converted',
help='Converted model name'
)
parser.add_argument(
'-s', '--suffix',
default='',
help='Type suffix (e.g., "-embeddings")'
)
parser.add_argument(
'-d', '--data-dir',
default='data',
help='Directory containing token files (default: data)'
)
parser.add_argument(
'-v', '--verbose',
action='store_true',
help='Print prompts from both models'
)
return parser.parse_args()
def main():
args = parse_arguments()
if args.verbose:
from pathlib import Path
data_dir = Path(args.data_dir)
prompt1_file = data_dir / f"{args.original}{args.suffix}-prompt.txt"
prompt2_file = data_dir / f"{args.converted}{args.suffix}-prompt.txt"
if prompt1_file.exists():
print(f"\nOriginal model prompt ({args.original}):")
print(f" {prompt1_file.read_text().strip()}")
if prompt2_file.exists():
print(f"\nConverted model prompt ({args.converted}):")
print(f" {prompt2_file.read_text().strip()}")
print()
result = compare_tokens(
args.original,
args.converted,
type_suffix=args.suffix,
output_dir=args.data_dir
)
# Exit with a status code so that shell scripts can check
# for success/failure.
sys.exit(0 if result else 1)
if __name__ == "__main__":
main()
View File
@ -4,10 +4,8 @@ import numpy as np
import argparse import argparse
import os import os
import importlib import importlib
from pathlib import Path
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, AutoModel
from common import compare_tokens # type: ignore[import-not-found]
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@ -159,25 +157,9 @@ def main():
else: else:
prompt = args.prompt prompt = args.prompt
python_emb_path = Path(args.python_embeddings)
cpp_emb_path = Path(args.cpp_embeddings)
# Extract base names (e.g., "pytorch-model-name-embeddings.bin" -> "pytorch-model-name")
python_model_name = python_emb_path.stem.replace("-embeddings", "")
cpp_model_name = cpp_emb_path.stem.replace("-embeddings", "")
print("Semantic Similarity Test Between Python and llama.cpp Embedding Models") print("Semantic Similarity Test Between Python and llama.cpp Embedding Models")
print("=" * 70) print("=" * 70)
# First verify tokens match before comparing embeddings
print("\n🔍 Token Comparison Check")
print("=" * 70)
data_dir = python_emb_path.parent
if not compare_tokens(python_model_name, cpp_model_name, type_suffix="-embeddings", output_dir=str(data_dir)):
print("\n❌ Token mismatch detected")
exit(1)
print()
# Single prompt detailed comparison # Single prompt detailed comparison
print(f"\nTesting with prompt: '{prompt}'") print(f"\nTesting with prompt: '{prompt}'")
View File
@ -217,8 +217,8 @@ int main(int argc, char ** argv) {
struct llama_batch batch = llama_batch_init(n_batch, 0, 1); struct llama_batch batch = llama_batch_init(n_batch, 0, 1);
// allocate output // allocate output
const int n_embd_out = llama_model_n_embd_out(model); const int n_embd = llama_model_n_embd(model);
std::vector<float> embeddings(n_chunks * n_embd_out, 0); std::vector<float> embeddings(n_chunks * n_embd, 0);
float * emb = embeddings.data(); float * emb = embeddings.data();
// break into batches // break into batches
@ -232,8 +232,8 @@ int main(int argc, char ** argv) {
// encode if at capacity // encode if at capacity
if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) { if (batch.n_tokens + n_toks > n_batch || s >= llama_n_seq_max(ctx)) {
float * out = emb + p * n_embd_out; float * out = emb + p * n_embd;
batch_process(ctx, batch, out, s, n_embd_out); batch_process(ctx, batch, out, s, n_embd);
common_batch_clear(batch); common_batch_clear(batch);
p += s; p += s;
s = 0; s = 0;
@ -245,12 +245,12 @@ int main(int argc, char ** argv) {
} }
// final batch // final batch
float * out = emb + p * n_embd_out; float * out = emb + p * n_embd;
batch_process(ctx, batch, out, s, n_embd_out); batch_process(ctx, batch, out, s, n_embd);
// save embeddings to chunks // save embeddings to chunks
for (int i = 0; i < n_chunks; i++) { for (int i = 0; i < n_chunks; i++) {
chunks[i].embedding = std::vector<float>(emb + i * n_embd_out, emb + (i + 1) * n_embd_out); chunks[i].embedding = std::vector<float>(emb + i * n_embd, emb + (i + 1) * n_embd);
// clear tokens as they are no longer needed // clear tokens as they are no longer needed
chunks[i].tokens.clear(); chunks[i].tokens.clear();
} }
@ -266,8 +266,8 @@ int main(int argc, char ** argv) {
batch_add_seq(query_batch, query_tokens, 0); batch_add_seq(query_batch, query_tokens, 0);
std::vector<float> query_emb(n_embd_out, 0); std::vector<float> query_emb(n_embd, 0);
batch_process(ctx, query_batch, query_emb.data(), 1, n_embd_out); batch_process(ctx, query_batch, query_emb.data(), 1, n_embd);
common_batch_clear(query_batch); common_batch_clear(query_batch);
@ -275,7 +275,7 @@ int main(int argc, char ** argv) {
{ {
std::vector<std::pair<int, float>> similarities; std::vector<std::pair<int, float>> similarities;
for (int i = 0; i < n_chunks; i++) { for (int i = 0; i < n_chunks; i++) {
float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd_out); float sim = common_embd_similarity_cos(chunks[i].embedding.data(), query_emb.data(), n_embd);
similarities.push_back(std::make_pair(i, sim)); similarities.push_back(std::make_pair(i, sim));
} }
View File
@ -4,7 +4,7 @@ project("ggml" C CXX ASM)
### GGML Version ### GGML Version
set(GGML_VERSION_MAJOR 0) set(GGML_VERSION_MAJOR 0)
set(GGML_VERSION_MINOR 9) set(GGML_VERSION_MINOR 9)
set(GGML_VERSION_PATCH 5) set(GGML_VERSION_PATCH 4)
set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}") set(GGML_VERSION_BASE "${GGML_VERSION_MAJOR}.${GGML_VERSION_MINOR}.${GGML_VERSION_PATCH}")
find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH) find_program(GIT_EXE NAMES git git.exe NO_CMAKE_FIND_ROOT_PATH)
View File
@ -358,7 +358,7 @@ extern "C" {
typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data); typedef bool (*ggml_backend_eval_callback)(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data);
// Compare the output of two backends // Compare the output of two backends
GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes); GGML_API bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node);
// Tensor initialization // Tensor initialization
GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr); GGML_API enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr);
View File
@ -2053,7 +2053,7 @@ void ggml_backend_graph_copy_free(struct ggml_backend_graph_copy copy) {
ggml_free(copy.ctx_unallocated); ggml_free(copy.ctx_unallocated);
} }
bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor const * const * test_nodes, size_t num_test_nodes) { bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t backend2, struct ggml_cgraph * graph, ggml_backend_eval_callback callback, void * user_data, struct ggml_tensor * test_node) {
struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph); struct ggml_backend_graph_copy copy = ggml_backend_graph_copy(backend2, graph);
if (copy.buffer == NULL) { if (copy.buffer == NULL) {
return false; return false;
@ -2064,22 +2064,22 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
assert(g1->n_nodes == g2->n_nodes); assert(g1->n_nodes == g2->n_nodes);
if (num_test_nodes != 0) { if (test_node != nullptr) {
GGML_ASSERT(test_nodes); // Compute the whole graph and only test the output for a specific tensor
// Compute the whole graph and only test the output for specific tensors
ggml_backend_graph_compute(backend1, g1); ggml_backend_graph_compute(backend1, g1);
ggml_backend_graph_compute(backend2, g2); ggml_backend_graph_compute(backend2, g2);
bool verified = false; int test_node_idx = -1;
for (int i = 0; i < g1->n_nodes; i++) { for (int i = 0; i < g1->n_nodes; i++) {
for (size_t j = 0; j < num_test_nodes; ++j) { struct ggml_tensor * t1 = g1->nodes[i];
if (g1->nodes[i] == test_nodes[j]) { if (t1 == test_node) {
callback(i, g1->nodes[i], g2->nodes[i], user_data); test_node_idx = i;
verified = true; break;
}
} }
} }
GGML_ASSERT(verified); GGML_ASSERT(test_node_idx != -1);
callback(test_node_idx, g1->nodes[test_node_idx], g2->nodes[test_node_idx], user_data);
} else { } else {
for (int i = 0; i < g1->n_nodes; i++) { for (int i = 0; i < g1->n_nodes; i++) {
struct ggml_tensor * t1 = g1->nodes[i]; struct ggml_tensor * t1 = g1->nodes[i];
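As a usage sketch, the single-test-node form of this API on the b7588 side could be driven as follows (callback name and comparison logic are illustrative; a real comparison would read the tensors back with ggml_backend_tensor_get):

#include <stdio.h>

// invoked once for the node computed on both backends
static bool eval_cb(int node_index, struct ggml_tensor * t1, struct ggml_tensor * t2, void * user_data) {
    (void) t2; (void) user_data;
    fprintf(stderr, "comparing node %d (%s)\n", node_index, t1->name);
    // ... read back t1/t2 and compare their contents here ...
    return true; // assumed contract: true lets the comparison continue
}

// compute the graph on both backends, invoking eval_cb only for `node`:
// ggml_backend_compare_graph_backend(backend1, backend2, graph, eval_cb, NULL, node);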
View File
@ -26,7 +26,6 @@
#include "ggml.h" #include "ggml.h"
#include <aclnnop/aclnn_add.h> #include <aclnnop/aclnn_add.h>
#include <aclnnop/aclnn_add_rms_norm.h>
#include <aclnnop/aclnn_addcdiv.h> #include <aclnnop/aclnn_addcdiv.h>
#include <aclnnop/aclnn_argmax.h> #include <aclnnop/aclnn_argmax.h>
#include <aclnnop/aclnn_avgpool2d.h> #include <aclnnop/aclnn_avgpool2d.h>
@ -1963,7 +1962,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context & ctx, ggml_tensor *
acl_tensor_ptr acl_weight_tensor; acl_tensor_ptr acl_weight_tensor;
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) { if (weight_to_nz && is_matmul_weight(weight)) {
acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ); acl_weight_tensor = ggml_cann_create_tensor(weight, transpose_ne, transpose_nb, n_dims, ACL_FORMAT_FRACTAL_NZ);
} else { } else {
@ -3806,57 +3805,3 @@ void ggml_cann_ssm_conv(ggml_backend_cann_context & ctx, ggml_tensor * dst) {
cubeMathType); cubeMathType);
} }
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx,
ggml_tensor * add_node,
ggml_tensor * rms_norm_node) {
// Get the two input tensors for ADD operation
ggml_tensor * x1 = add_node->src[0];
ggml_tensor * x2 = add_node->src[1];
// Create ACL tensors for the two ADD inputs
acl_tensor_ptr acl_x1 = ggml_cann_create_tensor(x1);
acl_tensor_ptr acl_x2 = ggml_cann_create_tensor(x2);
// Get epsilon parameter from rms_norm_tensor
float eps;
memcpy(&eps, rms_norm_node->op_params, sizeof(float));
// Build gamma tensor (RMS normalization scaling factor)
// Gamma should match the normalized dimensions (last dimension of x1)
size_t acl_gamma_nb[GGML_MAX_DIMS];
acl_gamma_nb[0] = ggml_type_size(rms_norm_node->type);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
acl_gamma_nb[i] = acl_gamma_nb[i - 1] * x1->ne[i - 1];
}
acl_tensor_ptr acl_gamma =
get_cache_acl_tensor(ctx, &ctx.rms_norm_one_tensor_cache.cache, ctx.rms_norm_one_tensor_cache.size, x1->ne,
acl_gamma_nb, rms_norm_node->type,
1, // dims - only the last dimension
1.0f // value
);
// Build rstdOut tensor (output for the reciprocal standard deviation)
// Shape should be the dimensions that are NOT normalized
int64_t acl_rstd_ne[] = { 1, x1->ne[1], x1->ne[2], x1->ne[3] };
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
}
acl_tensor_ptr acl_rstd =
get_cache_acl_tensor(ctx, &ctx.rms_norm_zero_tensor_cache.cache, ctx.rms_norm_zero_tensor_cache.size,
acl_rstd_ne, acl_rstd_nb, GGML_TYPE_F32, GGML_MAX_DIMS,
0.0f // value
);
acl_tensor_ptr acl_xout = ggml_cann_create_tensor(add_node);
// Create yOut tensor (final output after RMS normalization)
acl_tensor_ptr acl_yout = ggml_cann_create_tensor(rms_norm_node);
// Call fused ADD + RMS_NORM operator
GGML_CANN_CALL_ACLNN_OP(ctx, AddRmsNorm, acl_x1.get(), acl_x2.get(), acl_gamma.get(),
eps, // double type
acl_yout.get(), acl_rstd.get(), acl_xout.get());
}
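For reference, the removed fusion computes the same result as an ADD node followed by an RMS_NORM node. A scalar sketch over one row of length n (illustrative; the ACL kernel operates on whole tensors):

#include <math.h>

// x_out = x1 + x2, then y_out = rms_norm(x_out) * gamma
static void add_rms_norm_ref(const float * x1, const float * x2, const float * gamma,
                             float * x_out, float * y_out, int n, float eps) {
    double sum_sq = 0.0;
    for (int i = 0; i < n; i++) {
        x_out[i] = x1[i] + x2[i];                // output of the ADD node
        sum_sq  += (double) x_out[i] * x_out[i];
    }
    const float rstd = 1.0f / sqrtf((float) (sum_sq / n) + eps); // reciprocal RMS
    for (int i = 0; i < n; i++) {
        y_out[i] = x_out[i] * rstd * gamma[i];   // output of the RMS_NORM node
    }
}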
View File
@ -935,20 +935,6 @@ template <typename... Args> void register_acl_resources(std::vector<any_acl_reso
*/ */
void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst); void ggml_cann_mul_mat_id(ggml_backend_cann_context & ctx, ggml_tensor * dst);
/**
* @brief Performs fused ADD + RMS_NORM operation using the CANN backend.
*
* This function fuses the ADD and RMS_NORM operations into a single kernel call
* for better performance. It first adds two input tensors (x1 + x2), then applies
* RMS normalization to the result.
*
* @param ctx The context for the CANN backend operations.
* @param add_node The ADD operation node; contains the two input tensors to be added.
* @param rms_norm_node The RMS_NORM operation node; contains the gamma weights
* and epsilon parameter.
*/
void ggml_cann_op_add_rms_norm_fused(ggml_backend_cann_context & ctx, ggml_tensor * add_node, ggml_tensor * rms_norm_node);
/** /**
* @brief Check whether a tensor is a weight tensor for matrix multiplication. * @brief Check whether a tensor is a weight tensor for matrix multiplication.
* *
View File
@ -103,7 +103,7 @@ const ggml_cann_device_info & ggml_cann_info();
void ggml_cann_set_device(int32_t device); void ggml_cann_set_device(int32_t device);
int32_t ggml_cann_get_device(); int32_t ggml_cann_get_device();
std::optional<std::string> get_env_as_lowercase(const std::string & name); std::optional<std::string> get_env(const std::string & name);
bool parse_bool(const std::string & value); bool parse_bool(const std::string & value);
int parse_integer(const std::string & value); int parse_integer(const std::string & value);
View File
@ -105,10 +105,10 @@ int32_t ggml_cann_get_device() {
} }
/** /**
* @brief Get the value of the specified environment variable (name) as lowercase. * @brief Get the value of the specified environment variable (name).
* If set, the value is returned as a std::string object. * If set, the value is returned as a std::string object.
*/ */
std::optional<std::string> get_env_as_lowercase(const std::string & name) { std::optional<std::string> get_env(const std::string & name) {
const char * val = std::getenv(name.c_str()); const char * val = std::getenv(name.c_str());
if (!val) { if (!val) {
return std::nullopt; return std::nullopt;
@ -122,7 +122,7 @@ std::optional<std::string> get_env_as_lowercase(const std::string & name) {
* @brief Check whether the given value is a valid truthy string. * @brief Check whether the given value is a valid truthy string.
*/ */
bool parse_bool(const std::string & value) { bool parse_bool(const std::string & value) {
static const std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" }; std::unordered_set<std::string> valid_values = { "on", "1", "yes", "y", "enable", "true" };
return valid_values.find(value) != valid_values.end(); return valid_values.find(value) != valid_values.end();
} }
@ -259,7 +259,7 @@ struct ggml_cann_pool_buf_prio : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool. * @param device The device ID to associate with this buffer pool.
*/ */
explicit ggml_cann_pool_buf_prio(int device) : device(device) { explicit ggml_cann_pool_buf_prio(int device) : device(device) {
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
} }
/** /**
@ -452,7 +452,7 @@ struct ggml_cann_pool_buf : public ggml_cann_pool {
* @param device The device ID to associate with this buffer pool. * @param device The device ID to associate with this buffer pool.
*/ */
explicit ggml_cann_pool_buf(int device) : device(device) { explicit ggml_cann_pool_buf(int device) : device(device) {
disable_clean = parse_bool(get_env_as_lowercase("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or("")); disable_clean = parse_bool(get_env("GGML_CANN_DISABLE_BUF_POOL_CLEAN").value_or(""));
} }
/** /**
@ -764,7 +764,7 @@ struct ggml_cann_pool_vmm : public ggml_cann_pool {
* @return A unique pointer to the created CANN pool. * @return A unique pointer to the created CANN pool.
*/ */
std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) { std::unique_ptr<ggml_cann_pool> ggml_backend_cann_context::new_pool_for_device(int device) {
std::string mem_pool_type = get_env_as_lowercase("GGML_CANN_MEM_POOL").value_or(""); std::string mem_pool_type = get_env("GGML_CANN_MEM_POOL").value_or("");
if (mem_pool_type == "prio") { if (mem_pool_type == "prio") {
GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device); GGML_LOG_INFO("%s: device %d use buffer pool with priority queue\n", __func__, device);
@ -1217,7 +1217,7 @@ static void ggml_backend_cann_buffer_set_tensor(ggml_backend_buffer_t buffer,
// Why aclrtSynchronizeDevice? // Why aclrtSynchronizeDevice?
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) { if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE)); ACL_CHECK(aclrtMemcpy((char *) tensor->data + offset, size, data, size, ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) { if (weight_to_nz && is_matmul_weight((const ggml_tensor *) tensor)) {
@ -1442,7 +1442,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(ggml_backend_buffer_t
int64_t ne0 = tensor->ne[0]; int64_t ne0 = tensor->ne[0];
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env_as_lowercase("GGML_CANN_WEIGHT_NZ").value_or("on")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
// last line must be bigger than 32, because every single op deals // last line must be bigger than 32, because every single op deals
// with at least 32 bytes. // with at least 32 bytes.
@ -1888,7 +1888,6 @@ static bool ggml_cann_compute_forward(ggml_backend_cann_context & ctx, struct gg
break; break;
case GGML_OP_OUT_PROD: case GGML_OP_OUT_PROD:
ggml_cann_out_prod(ctx, dst); ggml_cann_out_prod(ctx, dst);
break;
case GGML_OP_SSM_CONV: case GGML_OP_SSM_CONV:
ggml_cann_ssm_conv(ctx, dst); ggml_cann_ssm_conv(ctx, dst);
break; break;
@ -2078,40 +2077,6 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream())); ACL_CHECK(aclrtSynchronizeStream(cann_ctx->stream()));
} }
/**
* @brief Check if CANN backend can fuse the specified operation sequence
*
* This function determines whether an operation sequence starting from the specified node
* can be fused into an optimized operation in the CANN backend. Operation fusion can reduce
* memory access overhead and improve computational efficiency.
*
* @param cgraph Pointer to the computation graph
* @param node_idx Index of the starting node in the computation graph
* @param ops Sequence of operation types to check for fusion
* @return true if the operations can be fused
* @return false if the operations cannot be fused
*/
static bool ggml_cann_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops) {
if (!ggml_can_fuse(cgraph, node_idx, ops)) {
return false;
}
// CANN backend supports fusing ADD + RMS_NORM operations
if ((ops.size() == 2) && ops.begin()[0] == GGML_OP_ADD && ops.begin()[1] == GGML_OP_RMS_NORM) {
ggml_tensor * add_node = cgraph->nodes[node_idx];
// TODO: support broadcast for ADD + RMS_NORM
if (add_node->src[0]->ne[0] != add_node->src[1]->ne[0] || add_node->src[0]->ne[1] != add_node->src[1]->ne[1] ||
add_node->src[0]->ne[2] != add_node->src[1]->ne[2] || add_node->src[0]->ne[3] != add_node->src[1]->ne[3]) {
return false;
}
return true;
}
return false;
}
/** /**
* @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API. * @brief Evaluate the computation graph and optionally capture or execute it using CANN graph API.
* *
@ -2136,18 +2101,9 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch. // With the use of CANN graphs, the execution will be performed by the graph launch.
static bool opt_fusion = parse_bool(get_env_as_lowercase("GGML_CANN_OPERATOR_FUSION").value_or(""));
if (!use_cann_graph || cann_graph_capture_required) { if (!use_cann_graph || cann_graph_capture_required) {
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i]; ggml_tensor * node = cgraph->nodes[i];
if (opt_fusion) {
if (ggml_cann_can_fuse(cgraph, i, { GGML_OP_ADD, GGML_OP_RMS_NORM })) {
ggml_cann_op_add_rms_norm_fused(*cann_ctx, node, cgraph->nodes[i + 1]);
i++;
continue;
}
}
if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || if (ggml_is_empty(node) || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE ||
node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) { node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE || node->op == GGML_OP_NONE) {
@ -2201,7 +2157,7 @@ static enum ggml_status ggml_backend_cann_graph_compute(ggml_backend_t backend,
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
bool use_cann_graph = true; bool use_cann_graph = true;
static bool prefill_use_graph = parse_bool(get_env_as_lowercase("GGML_CANN_PREFILL_USE_GRAPH").value_or("")); static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) { if (!prefill_use_graph) {
// Do not use acl_graph for prefill. // Do not use acl_graph for prefill.
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
View File
@ -54,20 +54,6 @@ if (CUDAToolkit_FOUND)
enable_language(CUDA) enable_language(CUDA)
# TODO: Remove once CCCL 3.2 has been released and bundled with CUDA Toolkit
if (GGML_CUDA_CUB_3DOT2)
include(FetchContent)
FetchContent_Declare(
CCCL
GIT_REPOSITORY https://github.com/nvidia/cccl.git
GIT_TAG v3.2.0-rc2
GIT_SHALLOW TRUE
)
FetchContent_MakeAvailable(CCCL)
endif()
# Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa. # Replace any plain 12X CUDA architectures with their "architecture-specific" equivalents 12Xa.
# 12X is forwards-compatible, 12Xa is not. # 12X is forwards-compatible, 12Xa is not.
# Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa. # Notably the Blackwell FP4 tensor core instructions are not forwards compatible and therefore need 12Xa.
@ -157,9 +143,6 @@ if (CUDAToolkit_FOUND)
# As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library # As of 12.3.1 CUDA Toolkit for Windows does not offer a static cublas library
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas) target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas)
else () else ()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1") if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL "10.1")
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static) target_link_libraries(ggml-cuda PRIVATE CUDA::cudart_static CUDA::cublas_static CUDA::cublasLt_static)
else() else()
@ -167,9 +150,6 @@ if (CUDAToolkit_FOUND)
endif() endif()
endif() endif()
else() else()
if (GGML_CUDA_CUB_3DOT2)
target_link_libraries(ggml-cuda PRIVATE CCCL::CCCL)
endif()
target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas) target_link_libraries(ggml-cuda PRIVATE CUDA::cudart CUDA::cublas)
endif() endif()
@ -238,10 +218,6 @@ if (CUDAToolkit_FOUND)
if (NOT MSVC) if (NOT MSVC)
list(APPEND CUDA_CXX_FLAGS -Wno-pedantic) list(APPEND CUDA_CXX_FLAGS -Wno-pedantic)
else()
# CCCL 3.2 onwards will require a cpp-standard-compliant preprocessor for MSVC
# https://github.com/NVIDIA/cccl/pull/6827
list(APPEND CUDA_CXX_FLAGS /Zc:preprocessor)
endif() endif()
list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument list(JOIN CUDA_CXX_FLAGS " " CUDA_CXX_FLAGS_JOINED) # pass host compiler flags as a single argument
View File
@ -22,13 +22,13 @@ static __global__ void init_offsets(int * offsets, const int ncols, const int nr
} }
#ifdef GGML_CUDA_USE_CUB #ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool, static void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x, const float * x,
int * dst, int * dst,
const int ncols, const int ncols,
const int nrows, const int nrows,
ggml_sort_order order, ggml_sort_order order,
cudaStream_t stream) { cudaStream_t stream) {
ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc<int> temp_indices_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows); ggml_cuda_pool_alloc<float> temp_keys_alloc(pool, ncols * nrows);
ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1); ggml_cuda_pool_alloc<int> offsets_alloc(pool, nrows + 1);
@ -49,49 +49,28 @@ void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
size_t temp_storage_bytes = 0; size_t temp_storage_bytes = 0;
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) { DeviceSegmentedRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
DeviceRadixSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, // values (indices)
temp_indices, dst, // values (indices) ncols * nrows, nrows, // num items, num segments
ncols, 0, sizeof(float) * 8, stream); d_offsets, d_offsets + 1, 0, sizeof(float) * 8, // all bits
} else { stream);
DeviceSegmentedSort::SortPairs(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place)
temp_indices, dst, // values (indices)
ncols * nrows, nrows, // num items, num segments
d_offsets, d_offsets + 1, stream);
}
} else { } else {
if (nrows == 1) { DeviceSegmentedRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
DeviceRadixSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, 0,
temp_indices, dst, // values (indices) sizeof(float) * 8, stream);
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(nullptr, temp_storage_bytes, temp_keys, temp_keys, temp_indices,
dst, ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} }
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes); ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get(); void * d_temp_storage = temp_storage_alloc.get();
if (order == GGML_SORT_ORDER_ASC) { if (order == GGML_SORT_ORDER_ASC) {
if (nrows == 1) { DeviceSegmentedRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
DeviceRadixSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) ncols * nrows, nrows, d_offsets, d_offsets + 1, 0, sizeof(float) * 8,
temp_indices, dst, // values (indices) stream);
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairs(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, temp_indices, dst,
ncols * nrows, nrows, d_offsets, d_offsets + 1, stream);
}
} else { } else {
if (nrows == 1) { DeviceSegmentedRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
DeviceRadixSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys, // keys (in-place) temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
temp_indices, dst, // values (indices) 0, sizeof(float) * 8, stream);
ncols, 0, sizeof(float) * 8, stream);
} else {
DeviceSegmentedSort::SortPairsDescending(d_temp_storage, temp_storage_bytes, temp_keys, temp_keys,
temp_indices, dst, ncols * nrows, nrows, d_offsets, d_offsets + 1,
stream);
}
} }
} }
#endif // GGML_CUDA_USE_CUB #endif // GGML_CUDA_USE_CUB
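Both variants above follow CUB's standard two-phase pattern: a first call with a null temp-storage pointer only computes the required scratch size, and a second call with identical arguments performs the sort. A condensed sketch of that idiom (buffer names are illustrative; the real code draws scratch space from the ggml CUDA pool rather than cudaMalloc):

#include <cub/cub.cuh>

// segmented ascending sort of float keys (in-place) carrying int values
static void segmented_sort_pairs_sketch(float * keys, const int * vals_in, int * vals_out,
                                        int n_items, int n_segments,
                                        const int * d_offsets, cudaStream_t stream) {
    size_t temp_bytes = 0;
    // phase 1: null scratch pointer -> CUB only reports the required size
    cub::DeviceSegmentedRadixSort::SortPairs(nullptr, temp_bytes, keys, keys, vals_in, vals_out,
                                             n_items, n_segments, d_offsets, d_offsets + 1,
                                             0, sizeof(float) * 8, stream);
    void * d_temp = nullptr;
    cudaMalloc(&d_temp, temp_bytes);
    // phase 2: same arguments with real scratch space -> performs the sort
    cub::DeviceSegmentedRadixSort::SortPairs(d_temp, temp_bytes, keys, keys, vals_in, vals_out,
                                             n_items, n_segments, d_offsets, d_offsets + 1,
                                             0, sizeof(float) * 8, stream);
    cudaFree(d_temp); // implicitly synchronizes the device
}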
@ -162,12 +141,12 @@ static int next_power_of_2(int x) {
return n; return n;
} }
void argsort_f32_i32_cuda_bitonic(const float * x, static void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst, int * dst,
const int ncols, const int ncols,
const int nrows, const int nrows,
ggml_sort_order order, ggml_sort_order order,
cudaStream_t stream) { cudaStream_t stream) {
// bitonic sort requires ncols to be a power of 2 // bitonic sort requires ncols to be a power of 2
const int ncols_pad = next_power_of_2(ncols); const int ncols_pad = next_power_of_2(ncols);
View File
@ -1,19 +1,3 @@
#include "common.cuh" #include "common.cuh"
void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
#ifdef GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_cub(ggml_cuda_pool & pool,
const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
#endif // GGML_CUDA_USE_CUB
void argsort_f32_i32_cuda_bitonic(const float * x,
int * dst,
const int ncols,
const int nrows,
ggml_sort_order order,
cudaStream_t stream);
View File
@ -950,16 +950,15 @@ struct ggml_cuda_device_info {
int device_count; int device_count;
struct cuda_device_info { struct cuda_device_info {
int cc; // compute capability int cc; // compute capability
int nsm; // number of streaming multiprocessors int nsm; // number of streaming multiprocessors
size_t smpb; // max. shared memory per block size_t smpb; // max. shared memory per block
size_t smpbo; // max. shared memory per block (with opt-in) size_t smpbo; // max. shared memory per block (with opt-in)
bool integrated; // Device is integrated as opposed to discrete bool integrated; // Device is integrated as opposed to discrete
bool vmm; // virtual memory support bool vmm; // virtual memory support
size_t vmm_granularity; // granularity of virtual memory size_t vmm_granularity; // granularity of virtual memory
size_t total_vram; size_t total_vram;
int warp_size; // Number of threads in a dispatch int warp_size; // Number of threads in a dispatch
bool supports_cooperative_launch; // whether cooperative launch is supported
}; };
cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {}; cuda_device_info devices[GGML_CUDA_MAX_DEVICES] = {};
@ -1036,7 +1035,7 @@ struct ggml_tensor_extra_gpu {
#define USE_CUDA_GRAPH #define USE_CUDA_GRAPH
#endif #endif
struct ggml_cuda_graph_node_properties { struct ggml_graph_node_properties {
void * node_address; void * node_address;
ggml_op node_op; ggml_op node_op;
int64_t ne[GGML_MAX_DIMS]; int64_t ne[GGML_MAX_DIMS];
@ -1059,27 +1058,12 @@ struct ggml_cuda_graph {
cudaGraphExec_t instance = nullptr; cudaGraphExec_t instance = nullptr;
size_t num_nodes = 0; size_t num_nodes = 0;
std::vector<cudaGraphNode_t> nodes; std::vector<cudaGraphNode_t> nodes;
std::vector<cudaKernelNodeParams> params;
bool disable_due_to_gpu_arch = false; bool disable_due_to_gpu_arch = false;
bool disable_due_to_too_many_updates = false; bool disable_due_to_too_many_updates = false;
bool disable_due_to_failed_graph_capture = false;
int number_consecutive_updates = 0; int number_consecutive_updates = 0;
std::vector<ggml_cuda_graph_node_properties> props; std::vector<ggml_graph_node_properties> ggml_graph_properties;
void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) {
number_consecutive_updates++;
} else {
number_consecutive_updates = 0;
}
if (number_consecutive_updates >= 4) {
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
disable_due_to_too_many_updates = true;
}
}
bool is_enabled() const {
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
return !(disable_due_to_gpu_arch || disable_cuda_graphs_due_to_env || disable_due_to_too_many_updates);
}
#endif #endif
}; };
View File
@ -12,11 +12,11 @@ const int CUDA_CPY_BLOCK_NM = 8; // block size of 3rd dimension if available
const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows const int CUDA_CPY_BLOCK_ROWS = 8; // block dimension for marching through rows
template <cpy_kernel_t cpy_1> template <cpy_kernel_t cpy_1>
static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne, static __global__ void cpy_scalar(const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int64_t nb12, const int64_t nb13) { const int nb12, const int nb13) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
@ -40,10 +40,10 @@ static __global__ void cpy_scalar(const char * cx, char * cdst, const int64_t ne
} }
template <typename T> template <typename T>
static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int64_t ne, static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int64_t nb12, const int64_t nb13) { const int nb12, const int nb13) {
const T* src = reinterpret_cast<const T*>(cx); const T* src = reinterpret_cast<const T*>(cx);
T* dst = reinterpret_cast<T*>(cdst); T* dst = reinterpret_cast<T*>(cdst);
@ -117,60 +117,60 @@ static __device__ void cpy_blck_q_f32(const char * cxi, char * cdsti) {
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_f32_q(const char * cx, char * cdst, const int64_t ne, static __global__ void cpy_f32_q(const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int64_t nb12, const int64_t nb13) { const int nb12, const int nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
const int64_t i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; const int x_offset = i00*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12); const int i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13; const int dst_offset = (i10/qk)*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
template <cpy_kernel_t cpy_blck, int qk> template <cpy_kernel_t cpy_blck, int qk>
static __global__ void cpy_q_f32(const char * cx, char * cdst, const int64_t ne, static __global__ void cpy_q_f32(const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11,
const int64_t nb12, const int64_t nb13) { const int nb12, const int nb13) {
const int64_t i = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*qk; const int i = (blockDim.x*blockIdx.x + threadIdx.x)*qk;
if (i >= ne) { if (i >= ne) {
return; return;
} }
const int64_t i03 = i/(ne00 * ne01 * ne02); const int i03 = i/(ne00 * ne01 * ne02);
const int64_t i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01); const int i02 = (i - i03*ne00*ne01*ne02 )/ (ne00*ne01);
const int64_t i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00; const int i01 = (i - i03*ne00*ne01*ne02 - i02*ne01*ne00) / ne00;
const int64_t i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00; const int i00 = i - i03*ne00*ne01*ne02 - i02*ne01*ne00 - i01*ne00;
const int64_t x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03; const int x_offset = (i00/qk)*nb00 + i01*nb01 + i02*nb02 + i03 * nb03;
const int64_t i13 = i/(ne10 * ne11 * ne12); const int i13 = i/(ne10 * ne11 * ne12);
const int64_t i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11); const int i12 = (i - i13*ne10*ne11*ne12) / (ne10*ne11);
const int64_t i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10; const int i11 = (i - i13*ne10*ne11*ne12 - i12*ne10*ne11) / ne10;
const int64_t i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10; const int i10 = i - i13*ne10*ne11*ne12 - i12*ne10*ne11 - i11*ne10;
const int64_t dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13; const int dst_offset = i10*nb10 + i11*nb11 + i12*nb12 + i13*nb13;
cpy_blck(cx + x_offset, cdst + dst_offset); cpy_blck(cx + x_offset, cdst + dst_offset);
} }
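For reference, the index math in both copy kernels above de-linearizes the flat element index i into 4D coordinates by mixed-radix division, then turns the coordinates into byte offsets via the nb strides. A minimal host-side round-trip check of that math (a standalone sketch; no such helper exists in the file itself):

    // Round-trip check of the mixed-radix index math used by cpy_scalar/cpy_q_f32.
    #include <cassert>
    #include <cstdint>

    int main() {
        const int64_t ne0 = 5, ne1 = 3, ne2 = 4, ne3 = 2; // arbitrary extents
        for (int64_t i = 0; i < ne0*ne1*ne2*ne3; ++i) {
            const int64_t i3 = i/(ne0*ne1*ne2);
            const int64_t i2 = (i - i3*ne0*ne1*ne2)/(ne0*ne1);
            const int64_t i1 = (i - i3*ne0*ne1*ne2 - i2*ne0*ne1)/ne0;
            const int64_t i0 = i - i3*ne0*ne1*ne2 - i2*ne0*ne1 - i1*ne0;
            assert(i == ((i3*ne2 + i2)*ne1 + i1)*ne0 + i0); // coordinates recompose to i
        }
        return 0;
    }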
template<typename src_t, typename dst_t> template<typename src_t, typename dst_t>
static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) { static __global__ void cpy_scalar_contiguous(const char * cx, char * cdst, const int64_t ne) {
const int64_t i = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; const int64_t i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= ne) { if (i >= ne) {
return; return;
@ -188,20 +188,19 @@ static void ggml_cpy_scalar_contiguous_cuda(
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_scalar_contiguous<src_t, dst_t><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne); (cx, cdst, ne);
} }
template<typename src_t, typename dst_t, bool transposed = false> template<typename src_t, typename dst_t, bool transposed = false>
static void ggml_cpy_scalar_cuda( static void ggml_cpy_scalar_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
if (transposed) { if (transposed) {
GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed GGML_ASSERT(ne == ne00*ne01*ne02); // ne[3] is 1 assumed
int64_t ne00n, ne01n, ne02n; int ne00n, ne01n, ne02n;
if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here if (nb00 <= nb02) { // most likely safe to handle nb00 = nb02 case here
ne00n = ne00; ne00n = ne00;
ne01n = ne01; ne01n = ne01;
@ -212,159 +211,143 @@ static void ggml_cpy_scalar_cuda(
ne02n = 1; ne02n = 1;
} }
int64_t grid_x = (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; dim3 dimGrid( (ne01n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
int64_t grid_y = (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D; (ne00n + CUDA_CPY_TILE_DIM_2D - 1) / CUDA_CPY_TILE_DIM_2D,
int64_t grid_z = (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM; (ne/(ne01n*ne00n) + CUDA_CPY_BLOCK_NM - 1) / CUDA_CPY_BLOCK_NM);
GGML_ASSERT(grid_x < UINT_MAX);
GGML_ASSERT(grid_y < USHRT_MAX);
GGML_ASSERT(grid_z < USHRT_MAX);
dim3 dimGrid(grid_x, grid_y, grid_z);
dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1); dim3 dimBlock(CUDA_CPY_TILE_DIM_2D, CUDA_CPY_BLOCK_ROWS, 1);
cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>> cpy_scalar_transpose<dst_t><<<dimGrid, dimBlock, 0, stream>>>
(cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00n, ne01n, ne02n, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} else { } else {
const int64_t num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE; const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>> cpy_scalar<cpy_1_scalar<src_t, dst_t>><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
} }
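cpy_scalar_transpose itself is outside this hunk; the launch geometry above (square tiles processed by TILE_DIM x BLOCK_ROWS threads, with the y/z grid dimensions asserted against the 65535 hardware limit while x may be much larger) matches the canonical shared-memory tiled transpose. A generic sketch of that pattern, not the actual kernel:

    // Canonical tiled transpose: stage a tile in padded shared memory, then write
    // it back with swapped block coordinates so both global accesses are coalesced.
    #include <cstdint>

    #define TILE_DIM   32 // illustrative; stands in for CUDA_CPY_TILE_DIM_2D
    #define BLOCK_ROWS 8  // illustrative; stands in for CUDA_CPY_BLOCK_ROWS

    template<typename T>
    static __global__ void transpose_tiled(const T * src, T * dst, int width, int height) {
        __shared__ T tile[TILE_DIM][TILE_DIM + 1]; // +1 column avoids bank conflicts

        int x = blockIdx.x*TILE_DIM + threadIdx.x;
        int y = blockIdx.y*TILE_DIM + threadIdx.y;
        for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
            if (x < width && y + j < height) {
                tile[threadIdx.y + j][threadIdx.x] = src[(y + j)*(int64_t)width + x];
            }
        }
        __syncthreads();

        x = blockIdx.y*TILE_DIM + threadIdx.x; // swapped block coordinates
        y = blockIdx.x*TILE_DIM + threadIdx.y;
        for (int j = 0; j < TILE_DIM; j += BLOCK_ROWS) {
            if (x < height && y + j < width) {
                dst[(y + j)*(int64_t)height + x] = tile[threadIdx.x][threadIdx.y + j];
            }
        }
    }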
static void ggml_cpy_f32_q8_0_cuda( static void ggml_cpy_f32_q8_0_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK8_0 == 0); GGML_ASSERT(ne % QK8_0 == 0);
const int64_t num_blocks = ne / QK8_0; const int num_blocks = ne / QK8_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q8_0, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
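This launcher (and the analogous ones below) maps one quant block to one single-thread CUDA block, so the grid size is simply the element count divided by the quant block size. For example, with QK8_0 = 32 (its value in ggml), copying a contiguous tensor of ne = 4096 floats into Q8_0 launches 4096 / 32 = 128 blocks, and the ne % QK8_0 == 0 assertion guarantees the division is exact.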
static void ggml_cpy_q8_0_f32_cuda( static void ggml_cpy_q8_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
const int64_t num_blocks = ne; const int num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>> cpy_q_f32<cpy_blck_q8_0_f32, QK8_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_0_cuda( static void ggml_cpy_f32_q4_0_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_0 == 0); GGML_ASSERT(ne % QK4_0 == 0);
const int64_t num_blocks = ne / QK4_0; const int num_blocks = ne / QK4_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_0, QK4_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_0_f32_cuda( static void ggml_cpy_q4_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int ne00, const int ne01, const int ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = ne; const int num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_0, QK4_0>, QK4_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q4_1_cuda( static void ggml_cpy_f32_q4_1_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_1 == 0); GGML_ASSERT(ne % QK4_1 == 0);
const int64_t num_blocks = ne / QK4_1; const int num_blocks = ne / QK4_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q4_1, QK4_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q4_1_f32_cuda( static void ggml_cpy_q4_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int ne00, const int ne01, const int ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = ne; const int num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q4_1, QK4_1>, QK4_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_0_cuda( static void ggml_cpy_f32_q5_0_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_0 == 0); GGML_ASSERT(ne % QK5_0 == 0);
const int64_t num_blocks = ne / QK5_0; const int num_blocks = ne / QK5_0;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_0, QK5_0><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_0_f32_cuda( static void ggml_cpy_q5_0_f32_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int ne00, const int ne01, const int ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = ne; const int num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_0, QK5_0>, QK5_0><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_q5_1_cuda( static void ggml_cpy_f32_q5_1_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK5_1 == 0); GGML_ASSERT(ne % QK5_1 == 0);
const int64_t num_blocks = ne / QK5_1; const int num_blocks = ne / QK5_1;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_q5_1, QK5_1><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_q5_1_f32_cuda( static void ggml_cpy_q5_1_f32_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int ne00, const int ne01, const int ne02,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int nb03, const int ne10, const int ne11, const int ne12,
const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, const int nb10, const int nb11, const int nb12, const int nb13,
cudaStream_t stream) { cudaStream_t stream) {
const int64_t num_blocks = ne; const int num_blocks = ne;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>( cpy_q_f32<cpy_blck_q_f32<dequantize_q5_1, QK5_1>, QK5_1><<<num_blocks, 1, 0, stream>>>(
cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03,
ne10, ne11, ne12, nb10, nb11, nb12, nb13); ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
static void ggml_cpy_f32_iq4_nl_cuda( static void ggml_cpy_f32_iq4_nl_cuda(
const char * cx, char * cdst, const int64_t ne, const char * cx, char * cdst, const int ne,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02,
const int64_t nb03, const int64_t ne10, const int64_t ne11, const int64_t ne12, const int64_t nb10, const int64_t nb11, const int64_t nb12, const int64_t nb13, cudaStream_t stream) { const int nb03, const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13, cudaStream_t stream) {
GGML_ASSERT(ne % QK4_NL == 0); GGML_ASSERT(ne % QK4_NL == 0);
const int64_t num_blocks = ne / QK4_NL; const int num_blocks = ne / QK4_NL;
GGML_ASSERT(num_blocks < UINT_MAX);
cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>> cpy_f32_q<cpy_blck_f32_iq4_nl, QK4_NL><<<num_blocks, 1, 0, stream>>>
(cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13); (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13);
} }
@ -373,6 +356,9 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1)); GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
const int64_t ne00 = src0->ne[0]; const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1]; const int64_t ne01 = src0->ne[1];
const int64_t ne02 = src0->ne[2]; const int64_t ne02 = src0->ne[2];

View File

@ -5,7 +5,7 @@
#include "ggml.h" #include "ggml.h"
#ifdef GGML_CUDA_USE_CUB #ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh> # include <cub/block/block_scan.cuh>
#endif // GGML_CUDA_USE_CUB #endif // GGML_CUDA_USE_CUB
template<typename T, int BLOCK_SIZE> template<typename T, int BLOCK_SIZE>
@ -185,34 +185,9 @@ static __global__ void cumsum_kernel(
} }
} }
#ifdef GGML_CUDA_USE_CUB
template <typename T>
static void cumsum_cub(ggml_cuda_pool & pool,
const T * src,
T * dst,
int64_t ne,
cudaStream_t stream) {
size_t tmp_size = 0;
// Query how much temp storage CUDA UnBound (CUB) needs
cub::DeviceScan::InclusiveSum(nullptr, // d_temp_storage (null = just query size)
tmp_size, // reference to size (will be set by CUB)
src, // input pointer
dst, // output pointer
ne, // number of elements
stream // CUDA stream to use
);
ggml_cuda_pool_alloc<uint8_t> tmp_alloc(pool, tmp_size);
// Perform the inclusive scan
cub::DeviceScan::InclusiveSum((void *) tmp_alloc.get(), tmp_size, src, dst, ne, stream);
}
#endif // GGML_CUDA_USE_CUB
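The null-pointer size query followed by the identical real call is the standard two-phase pattern for every cub::Device* algorithm; the function above just substitutes the ggml pool for the temporary allocation. A self-contained sketch of the same pattern with plain CUDA allocations:

    // Standalone inclusive prefix sum via CUB's two-phase call pattern.
    #include <cub/cub.cuh>

    static void inclusive_sum(const float * d_in, float * d_out, int n, cudaStream_t stream) {
        void * d_tmp = nullptr;
        size_t tmp_bytes = 0;
        // Phase 1: d_tmp == nullptr, so CUB only writes the required size to tmp_bytes.
        cub::DeviceScan::InclusiveSum(d_tmp, tmp_bytes, d_in, d_out, n, stream);
        cudaMalloc(&d_tmp, tmp_bytes);
        // Phase 2: identical call with real temp storage actually runs the scan.
        cub::DeviceScan::InclusiveSum(d_tmp, tmp_bytes, d_in, d_out, n, stream);
        cudaStreamSynchronize(stream); // ensure the scan is done before freeing
        cudaFree(d_tmp);
    }

The dispatch heuristic below only takes this per-row CUB path when rows are long enough to amortize the device-wide scan overhead: a single row longer than 1024 elements, or a row length more than 4096 times the row count.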
template<typename T> template<typename T>
static void cumsum_cuda( static void cumsum_cuda(
[[maybe_unused]] ggml_backend_cuda_context & ctx, const T * src, T * dst, const T * src, T * dst,
const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03, const int64_t ne00, const int64_t ne01, const int64_t ne02, const int64_t ne03,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03, const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3, const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
@ -226,15 +201,6 @@ static void cumsum_cuda(
if (is_contiguous) { if (is_contiguous) {
use_cub = true; use_cub = true;
const int64_t nrows = ne01 * ne02 * ne03;
// TODO: Compare with DeviceSegmentedScan::InclusiveSegmentedSum for nrows > 1 once InclusiveSegmentedSum is released
// Heuristics were determined as part of https://github.com/ggml-org/llama.cpp/pull/17004
if (((nrows == 1) && (ne00 > 1024)) || (ne00 / nrows > 4096)) {
for (int i=0; i<nrows; i++) {
cumsum_cub(ctx.pool(), src + i * ne00, dst + i * ne00, ne00, stream);
}
return;
}
} }
#endif // GGML_CUDA_USE_CUB #endif // GGML_CUDA_USE_CUB
dim3 grid_dims(ne01, ne02, ne03); dim3 grid_dims(ne01, ne02, ne03);
@ -273,7 +239,7 @@ void ggml_cuda_op_cumsum(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
{ {
cumsum_cuda( cumsum_cuda(
ctx, (const float *)src0->data, (float *)dst->data, (const float *)src0->data, (float *)dst->data,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3], src0->nb[0], src0->nb[1], src0->nb[2], src0->nb[3],
dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3], dst->nb[0], dst->nb[1], dst->nb[2], dst->nb[3],

View File

@ -11,12 +11,10 @@
#define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs. #define SOFTMAX_FTZ_THRESHOLD -20.0f // Softmax exp. of values smaller than this are flushed to zero to avoid NaNs.
// log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable // log(2) = 0.6931, by adding this to the KQ maximum used for the softmax the numerical range representable
// by the VKQ accumulators is effectively being shifted up by a factor of 2. // by the VKQ accumulators is effectively being shifted up by a factor of 8.
// This reduces issues with numerical overflow but also causes larger values to be flushed to zero. // This reduces issues with numerical overflow but also causes larger values to be flushed to zero.
// However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible. // However, as the output from FlashAttention will usually be used as an input for a matrix multiplication this should be negligible.
// Still, the value range should be shifted as much as necessary but as little as possible. #define FATTN_KQ_MAX_OFFSET 0.6931f
// The macro on the following line shifts it by a factor of 2**3=8, as was needed to fix https://github.com/ggml-org/llama.cpp/issues/18606 .
#define FATTN_KQ_MAX_OFFSET (3.0f*0.6931f)
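Concretely, shifting the subtracted maximum by k*log(2) scales every accumulated term by 2^-k: exp(x - (m + 3*0.6931f)) = exp(x - m) * 2^-3 = exp(x - m)/8. The VKQ accumulators therefore tolerate sums 8 times larger before overflowing, at the cost of the flush-to-zero cutoff (SOFTMAX_FTZ_THRESHOLD) effectively moving up by the same 3*log(2), roughly 2.08.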
typedef void (* fattn_kernel_t)( typedef void (* fattn_kernel_t)(
const char * __restrict__ Q, const char * __restrict__ Q,
@ -920,9 +918,7 @@ void launch_fattn(
blocks_num.y = 1; blocks_num.y = 1;
blocks_num.z = 1; blocks_num.z = 1;
if (ntiles_total % blocks_num.x != 0) { // Fixup is only needed if the SMs work on fractional tiles. dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float));
dst_tmp_meta.alloc((size_t(blocks_num.x) * ncols * (2 + DV/2)));
}
} else { } else {
const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size. const int ntiles_KQ = (K->ne[1] + nbatch_fa - 1) / nbatch_fa; // Max. number of parallel blocks limited by tensor size.

View File

@ -531,7 +531,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) { for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::I) {
#pragma unroll #pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) { for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::I + T_C_KQ::get_i(l) < k_VKQ_sup) { if (!oob_check || k0 + T_C_KQ::get_i(l) < k_VKQ_sup) {
KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET); KQ_max_new[l % 2] = fmaxf(KQ_max_new[l % 2], KQ_C[k0/(np*T_C_KQ::I)].x[l] + FATTN_KQ_MAX_OFFSET);
} }
} }
@ -583,7 +583,7 @@ static __device__ __forceinline__ void flash_attn_ext_f16_iter(
for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) { for (int k0 = 0; k0 < nbatch_fa; k0 += np*T_C_KQ::J) {
#pragma unroll #pragma unroll
for (int l = 0; l < T_C_KQ::ne; ++l) { for (int l = 0; l < T_C_KQ::ne; ++l) {
if (!oob_check || k0 + (threadIdx.y % np)*T_C_KQ::J + T_C_KQ::get_j(l) < k_VKQ_sup) { if (!oob_check || k0 + T_C_KQ::get_j(l) < k_VKQ_sup) {
// Turing + Volta: // Turing + Volta:
KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET); KQ_max_new[(l/2) % 2] = fmaxf(KQ_max_new[(l/2) % 2], KQ_C[(k0/(np*T_C_KQ::J))].x[l] + FATTN_KQ_MAX_OFFSET);
} }

View File

@ -19,7 +19,6 @@
#include "ggml-cuda/count-equal.cuh" #include "ggml-cuda/count-equal.cuh"
#include "ggml-cuda/cpy.cuh" #include "ggml-cuda/cpy.cuh"
#include "ggml-cuda/cross-entropy-loss.cuh" #include "ggml-cuda/cross-entropy-loss.cuh"
#include "ggml-cuda/cumsum.cuh"
#include "ggml-cuda/diagmask.cuh" #include "ggml-cuda/diagmask.cuh"
#include "ggml-cuda/diag.cuh" #include "ggml-cuda/diag.cuh"
#include "ggml-cuda/fattn.cuh" #include "ggml-cuda/fattn.cuh"
@ -45,7 +44,6 @@
#include "ggml-cuda/ssm-scan.cuh" #include "ggml-cuda/ssm-scan.cuh"
#include "ggml-cuda/sum.cuh" #include "ggml-cuda/sum.cuh"
#include "ggml-cuda/sumrows.cuh" #include "ggml-cuda/sumrows.cuh"
#include "ggml-cuda/top-k.cuh"
#include "ggml-cuda/mean.cuh" #include "ggml-cuda/mean.cuh"
#include "ggml-cuda/tsembd.cuh" #include "ggml-cuda/tsembd.cuh"
#include "ggml-cuda/topk-moe.cuh" #include "ggml-cuda/topk-moe.cuh"
@ -203,6 +201,16 @@ static ggml_cuda_device_info ggml_cuda_init() {
GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES); GGML_ASSERT(info.device_count <= GGML_CUDA_MAX_DEVICES);
int64_t total_vram = 0; int64_t total_vram = 0;
#ifdef GGML_CUDA_FORCE_MMQ
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_MMQ: no\n", __func__);
#endif // GGML_CUDA_FORCE_MMQ
#ifdef GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: yes\n", __func__);
#else
GGML_LOG_INFO("%s: GGML_CUDA_FORCE_CUBLAS: no\n", __func__);
#endif // GGML_CUDA_FORCE_CUBLAS
GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count); GGML_LOG_INFO("%s: found %d " GGML_CUDA_NAME " devices:\n", __func__, info.device_count);
std::vector<std::pair<int, std::string>> turing_devices_without_mma; std::vector<std::pair<int, std::string>> turing_devices_without_mma;
@ -233,14 +241,6 @@ static ggml_cuda_device_info ggml_cuda_init() {
info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].nsm = prop.multiProcessorCount;
info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].smpb = prop.sharedMemPerBlock;
info.devices[id].warp_size = prop.warpSize; info.devices[id].warp_size = prop.warpSize;
#ifndef GGML_USE_MUSA
int supports_coop_launch = 0;
CUDA_CHECK(cudaDeviceGetAttribute(&supports_coop_launch, cudaDevAttrCooperativeLaunch, id));
info.devices[id].supports_cooperative_launch = !!supports_coop_launch;
#else
info.devices[id].supports_cooperative_launch = false;
#endif // !(GGML_USE_MUSA)
#if defined(GGML_USE_HIP) #if defined(GGML_USE_HIP)
info.devices[id].smpbo = prop.sharedMemPerBlock; info.devices[id].smpbo = prop.sharedMemPerBlock;
@ -2687,9 +2687,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SUM: case GGML_OP_SUM:
ggml_cuda_op_sum(ctx, dst); ggml_cuda_op_sum(ctx, dst);
break; break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
ggml_cuda_op_sum_rows(ctx, dst); ggml_cuda_op_sum_rows(ctx, dst);
break; break;
@ -2702,9 +2699,6 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_SSM_SCAN: case GGML_OP_SSM_SCAN:
ggml_cuda_op_ssm_scan(ctx, dst); ggml_cuda_op_ssm_scan(ctx, dst);
break; break;
case GGML_OP_TOP_K:
ggml_cuda_op_top_k(ctx, dst);
break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
ggml_cuda_op_argsort(ctx, dst); ggml_cuda_op_argsort(ctx, dst);
break; break;
@ -2714,6 +2708,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CROSS_ENTROPY_LOSS: case GGML_OP_CROSS_ENTROPY_LOSS:
ggml_cuda_cross_entropy_loss(ctx, dst); ggml_cuda_cross_entropy_loss(ctx, dst);
break; break;
case GGML_OP_CUMSUM:
ggml_cuda_op_cumsum(ctx, dst);
break;
case GGML_OP_TRI: case GGML_OP_TRI:
ggml_cuda_op_tri(ctx, dst); ggml_cuda_op_tri(ctx, dst);
break; break;
@ -2853,9 +2850,9 @@ static void ggml_backend_cuda_synchronize(ggml_backend_t backend) {
} }
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) { static bool check_node_graph_compatibility(ggml_cgraph * cgraph,
bool use_cuda_graph) {
bool use_cuda_graph = true;
// Loop over nodes in GGML graph to obtain info needed for CUDA graph // Loop over nodes in GGML graph to obtain info needed for CUDA graph
const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected"; const std::string gemma3n_per_layer_proj_src0_name = "inp_per_layer_selected";
@ -2915,41 +2912,41 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
return use_cuda_graph; return use_cuda_graph;
} }
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { static void set_ggml_graph_node_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
props->node_address = node->data; graph_node_properties->node_address = node->data;
props->node_op = node->op; graph_node_properties->node_op = node->op;
for (int i = 0; i < GGML_MAX_DIMS; i++) { for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i]; graph_node_properties->ne[i] = node->ne[i];
props->nb[i] = node->nb[i]; graph_node_properties->nb[i] = node->nb[i];
} }
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; graph_node_properties->src_address[i] = node->src[i] ? node->src[i]->data : nullptr;
} }
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); memcpy(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS);
} }
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_graph_node_properties * graph_node_properties) {
if (node->data != props->node_address && if (node->data != graph_node_properties->node_address &&
node->op != GGML_OP_VIEW) { node->op != GGML_OP_VIEW) {
return false; return false;
} }
if (node->op != props->node_op) { if (node->op != graph_node_properties->node_op) {
return false; return false;
} }
for (int i = 0; i < GGML_MAX_DIMS; i++) { for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != props->ne[i]) { if (node->ne[i] != graph_node_properties->ne[i]) {
return false; return false;
} }
if (node->nb[i] != props->nb[i]) { if (node->nb[i] != graph_node_properties->nb[i]) {
return false; return false;
} }
} }
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
if (node->src[i] && if (node->src[i] &&
node->src[i]->data != props->src_address[i] && node->src[i]->data != graph_node_properties->src_address[i] &&
node->op != GGML_OP_VIEW node->op != GGML_OP_VIEW
) { ) {
return false; return false;
@ -2957,55 +2954,44 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
} }
if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) && if ((node->op == GGML_OP_SCALE || node->op == GGML_OP_GLU) &&
memcmp(props->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) { memcmp(graph_node_properties->op_params, node->op_params, GGML_MAX_OP_PARAMS) != 0) {
return false; return false;
} }
return true; return true;
} }
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { static bool is_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool res = false; bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->instance == nullptr) { if (cuda_ctx->cuda_graph->instance == nullptr) {
res = true; cuda_graph_update_required = true;
} }
// Check if the graph size has changed // Check if the graph size has changed
if (cuda_ctx->cuda_graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { if (cuda_ctx->cuda_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) {
res = true; cuda_graph_update_required = true;
cuda_ctx->cuda_graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); cuda_ctx->cuda_graph->ggml_graph_properties.resize(cgraph->n_nodes);
} }
// Loop over nodes in GGML graph to determine if CUDA graph update is required // Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token // and store properties to allow this comparison for the next token
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true; bool has_matching_properties = true;
if (!res) { if (!cuda_graph_update_required) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &cuda_ctx->cuda_graph->props[i]); has_matching_properties = ggml_graph_node_has_matching_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
} }
if (!props_match) { if (!has_matching_properties) {
res = true; cuda_graph_update_required = true;
} }
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[i], cgraph->nodes[i]); set_ggml_graph_node_properties(cgraph->nodes[i], &cuda_ctx->cuda_graph->ggml_graph_properties[i]);
} }
for (int i = 0; i < cgraph->n_leafs; i++) { return cuda_graph_update_required;
bool props_match = true;
if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &cuda_ctx->cuda_graph->props[cgraph->n_nodes + i]);
}
if (!props_match) {
res = true;
}
ggml_cuda_graph_node_set_properties(&cuda_ctx->cuda_graph->props[cgraph->n_nodes + i], cgraph->leafs[i]);
}
return res;
} }
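The property bookkeeping above exists so that an already-instantiated cudaGraphExec_t can be patched in place when only parameters changed, and rebuilt only when the topology changed. A minimal sketch of that capture/instantiate/update lifecycle (trivial kernel, no error handling; not the backend's actual flow):

    #include <cuda_runtime.h>

    static __global__ void scale_kernel(float * x, float s, int n) {
        const int i = blockIdx.x*blockDim.x + threadIdx.x;
        if (i < n) { x[i] *= s; }
    }

    // Capture the work into a graph; patch the existing executable when possible,
    // re-instantiate only if the update is rejected.
    static void run_graphed(float * d_x, float s, int n, cudaStream_t stream,
                            cudaGraphExec_t & instance) {
        cudaGraph_t graph = nullptr;
        cudaStreamBeginCapture(stream, cudaStreamCaptureModeRelaxed);
        scale_kernel<<<(n + 255)/256, 256, 0, stream>>>(d_x, s, n);
        cudaStreamEndCapture(stream, &graph);

        if (instance == nullptr) {
            cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
        } else {
            cudaGraphExecUpdateResultInfo info; // CUDA 12 signature, as used below
            if (cudaGraphExecUpdate(instance, graph, &info) != cudaSuccess) {
                cudaGraphExecDestroy(instance); // topology changed: rebuild
                cudaGraphInstantiate(&instance, graph, NULL, NULL, 0);
            }
        }
        cudaGraphDestroy(graph);
        cudaGraphLaunch(instance, stream);
    }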
static void ggml_cuda_graph_update_executable(ggml_backend_cuda_context * cuda_ctx) { static void update_cuda_graph_executable(ggml_backend_cuda_context * cuda_ctx) {
#if CUDART_VERSION >= 12000 #if CUDART_VERSION >= 12000
cudaGraphExecUpdateResultInfo result_info; cudaGraphExecUpdateResultInfo result_info;
@ -3236,11 +3222,10 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
return false; return false;
} }
static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, const bool use_cuda_graph, const bool cuda_graph_update_required) { static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph,
bool graph_evaluated_or_captured = false; bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) {
// flag used to determine whether it is an integrated_gpu // flag used to determine whether it is an integrated_gpu
const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated;
ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context(); ggml_cuda_stream_context & stream_ctx = cuda_ctx->stream_context();
bool is_concurrent_event_active = false; bool is_concurrent_event_active = false;
@ -3278,7 +3263,6 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid(); should_launch_concurrent_events = should_launch_concurrent_events && event.is_valid();
} }
} }
if (should_launch_concurrent_events) { if (should_launch_concurrent_events) {
// Restore original node order within each concurrent region to enable fusion within streams // Restore original node order within each concurrent region to enable fusion within streams
@ -3330,8 +3314,6 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]); cgraph->nodes[start_pos + i] = const_cast<ggml_tensor *>(event.original_order[i]);
} }
} }
} else {
stream_ctx.concurrent_events.clear();
} }
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
@ -3710,7 +3692,7 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0)); CUDA_CHECK(cudaGraphInstantiate(&cuda_ctx->cuda_graph->instance, cuda_ctx->cuda_graph->graph, NULL, NULL, 0));
} }
if (cuda_graph_update_required) { // Update graph executable if (cuda_graph_update_required) { // Update graph executable
ggml_cuda_graph_update_executable(cuda_ctx); update_cuda_graph_executable(cuda_ctx);
} }
// Launch graph // Launch graph
CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream())); CUDA_CHECK(cudaGraphLaunch(cuda_ctx->cuda_graph->instance, cuda_ctx->stream()));
@ -3720,45 +3702,60 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
} }
} }
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx) { static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context;
ggml_cuda_set_device(cuda_ctx->device);
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static const bool disable_cuda_graphs_due_to_env = (getenv("GGML_CUDA_DISABLE_GRAPHS") != nullptr);
// Objects required for CUDA Graph
if (cuda_ctx->cuda_graph == nullptr) { if (cuda_ctx->cuda_graph == nullptr) {
cuda_ctx->cuda_graph.reset(new ggml_cuda_graph()); cuda_ctx->cuda_graph.reset(new ggml_cuda_graph());
} }
bool use_cuda_graph = true;
bool cuda_graph_update_required = false;
if (cuda_ctx->cuda_graph->graph == nullptr) { if (cuda_ctx->cuda_graph->graph == nullptr) {
if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) { if (ggml_cuda_info().devices[cuda_ctx->device].cc < GGML_CUDA_CC_AMPERE) {
cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true; cuda_ctx->cuda_graph->disable_due_to_gpu_arch = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__); GGML_LOG_DEBUG("%s: disabling CUDA graphs due to GPU architecture\n", __func__);
#endif
} }
} }
return cuda_ctx->cuda_graph->is_enabled(); // Disable CUDA graphs in presence of env var, old GPU, use-case which is changing too rapidly,
#else // or previous graph capture failure.
return false; // Also disable for multi-gpu for now. TO DO investigate
#endif // USE_CUDA_GRAPH if (disable_cuda_graphs_due_to_env
} || cuda_ctx->cuda_graph->disable_due_to_gpu_arch
|| cuda_ctx->cuda_graph->disable_due_to_too_many_updates
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { || cuda_ctx->cuda_graph->disable_due_to_failed_graph_capture) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; use_cuda_graph = false;
}
ggml_cuda_set_device(cuda_ctx->device);
if (use_cuda_graph) {
bool use_cuda_graph = false; cuda_graph_update_required = is_cuda_graph_update_required(cuda_ctx, cgraph);
bool cuda_graph_update_required = false;
use_cuda_graph = check_node_graph_compatibility(cgraph, use_cuda_graph);
#ifdef USE_CUDA_GRAPH
use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx); // Disable CUDA graphs (from the next token) if the use-case is demanding too many consecutive graph updates.
if (use_cuda_graph && cuda_graph_update_required) {
if (cuda_ctx->cuda_graph->is_enabled()) { cuda_ctx->cuda_graph->number_consecutive_updates++;
cuda_graph_update_required = ggml_cuda_graph_update_required(cuda_ctx, cgraph); } else {
use_cuda_graph = ggml_cuda_graph_check_compability(cgraph); cuda_ctx->cuda_graph->number_consecutive_updates = 0;
}
cuda_ctx->cuda_graph->record_update(use_cuda_graph, cuda_graph_update_required);
if (cuda_ctx->cuda_graph->number_consecutive_updates >= 4) {
cuda_ctx->cuda_graph->disable_due_to_too_many_updates = true;
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: disabling CUDA graphs due to too many consecutive updates\n", __func__);
#endif
}
} }
#endif // USE_CUDA_GRAPH
if (use_cuda_graph && cuda_graph_update_required) { if (use_cuda_graph && cuda_graph_update_required) {
// Start CUDA graph capture // Start CUDA graph capture
@ -3770,7 +3767,14 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend,
CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed)); CUDA_CHECK(cudaStreamBeginCapture(cuda_ctx->stream(), cudaStreamCaptureModeRelaxed));
} }
ggml_cuda_graph_evaluate_and_capture(cuda_ctx, cgraph, use_cuda_graph, cuda_graph_update_required); #else
bool use_cuda_graph = false;
bool cuda_graph_update_required = false;
#endif // USE_CUDA_GRAPH
bool graph_evaluated_or_captured = false;
evaluate_and_capture_cuda_graph(cuda_ctx, cgraph, graph_evaluated_or_captured, use_cuda_graph, cuda_graph_update_required);
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
@ -3803,10 +3807,8 @@ static void ggml_backend_cuda_event_wait(ggml_backend_t backend, ggml_backend_ev
static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) { static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
const bool use_cuda_graph = ggml_cuda_graph_set_enabled(cuda_ctx);
static bool enable_graph_optimization = [] { static bool enable_graph_optimization = [] {
const char * env = getenv("GGML_CUDA_GRAPH_OPT"); const char * env = getenv("GGML_CUDA_GRAPH_OPT");
return env != nullptr && atoi(env) == 1; return env != nullptr && atoi(env) == 1;
}(); }();
@ -3814,13 +3816,12 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
return; return;
} }
GGML_ASSERT(ggml_backend_cuda_get_device_count() == 1 && "compute graph optimization is only supported on single GPU in the CUDA backend");
GGML_LOG_DEBUG("Optimizing CUDA graph %p with %d nodes\n", cgraph->nodes, cgraph->n_nodes);
ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context(); ggml_cuda_stream_context & stream_context = cuda_ctx->stream_context();
stream_context.reset(); stream_context.reset();
if (!use_cuda_graph || ggml_backend_cuda_get_device_count() != 1) {
return;
}
// number of out-degrees for a particular node // number of out-degrees for a particular node
std::unordered_map<const ggml_tensor *, int> fan_out; std::unordered_map<const ggml_tensor *, int> fan_out;
// reverse mapping of node to index in the cgraph // reverse mapping of node to index in the cgraph
@ -3881,12 +3882,6 @@ static void ggml_backend_cuda_graph_optimize(ggml_backend_t backend, ggml_cgraph
if (count >= min_fan_out && count <= max_fan_out) { if (count >= min_fan_out && count <= max_fan_out) {
const int root_node_idx = node_indices[root_node]; const int root_node_idx = node_indices[root_node];
// only optimize for attn_norm
// TODO: make this more generic
if (!strstr(root_node->name, "attn_norm")) {
continue;
}
bool is_part_of_event = false; bool is_part_of_event = false;
for (const auto & [start, end] : concurrent_node_ranges) { for (const auto & [start, end] : concurrent_node_ranges) {
if (root_node_idx >= start && root_node_idx <= end) { if (root_node_idx >= start && root_node_idx <= end) {
@ -4615,7 +4610,6 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return true; return true;
case GGML_OP_SUM: case GGML_OP_SUM:
return ggml_is_contiguous_rows(op->src[0]); return ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_TOP_K:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
#ifndef GGML_CUDA_USE_CUB #ifndef GGML_CUDA_USE_CUB
return op->src[0]->ne[0] <= 1024; return op->src[0]->ne[0] <= 1024;

View File

@ -34,11 +34,13 @@ void ggml_cuda_op_mean(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
// CUDA_GRAPHS_DISABLED // CUDA_GRAPHS_DISABLED
((ncols > 65536) && ((ncols > 65536) &&
((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || ((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled())) || ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture)) ||
// CUDA_GRAPHS ENABLED // CUDA_GRAPHS ENABLED
((ncols > 32768) && ((ncols > 32768) &&
!((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) || !((ctx.cuda_graph->instance == nullptr) && (iscapturing == cudaStreamCaptureStatusNone) ||
ctx.cuda_graph->is_enabled()))) { ctx.cuda_graph->disable_due_to_gpu_arch || ctx.cuda_graph->disable_due_to_too_many_updates ||
ctx.cuda_graph->disable_due_to_failed_graph_capture))) {
#else #else
(ncols > 65536)) { (ncols > 65536)) {
#endif // USE_CUDA_GRAPH #endif // USE_CUDA_GRAPH

View File

@ -333,28 +333,6 @@ bool ggml_cuda_should_use_mmq(enum ggml_type type, int cc, int64_t ne11, int64_t
} }
if (amd_wmma_available(cc)) { if (amd_wmma_available(cc)) {
// RDNA 4 is consistently worse on rocblas
// https://github.com/ggml-org/llama.cpp/pull/18537#issuecomment-3706422301
if (GGML_CUDA_CC_IS_RDNA3(cc)) {
// High expert counts almost always better on MMQ
// due to a large amount of graph splits
// https://github.com/ggml-org/llama.cpp/pull/18202
if (n_experts >= 64) {
return true;
}
switch (type) {
// These quants are really bad on MMQ
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q6_K:
// These quants are usually worse but not always
case GGML_TYPE_IQ2_XS:
case GGML_TYPE_IQ2_S:
return ne11 <= 128;
default:
return true;
}
}
return true; return true;
} }

View File

@ -1,14 +1,6 @@
#include "common.cuh" #include "common.cuh"
#include "ggml.h" #include "ggml.h"
#include "softmax.cuh" #include "softmax.cuh"
#ifdef GGML_USE_HIP
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#include <cooperative_groups/reduce.h>
#endif // GGML_USE_HIP
#include <cstdint> #include <cstdint>
#include <utility> #include <utility>
@ -168,156 +160,6 @@ static __global__ void soft_max_f32(
dst[col] = vals[col] * inv_sum; dst[col] = vals[col] * inv_sum;
} }
} }
// TODO: This is a common pattern used across kernels that could be moved to common.cuh + templated
static __device__ float two_stage_warp_reduce_max(float val) {
val = warp_reduce_max(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = -INFINITY;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_max(val);
} else {
return val;
}
}
static __device__ float two_stage_warp_reduce_sum(float val) {
val = warp_reduce_sum(val);
if (blockDim.x > WARP_SIZE) {
assert((blockDim.x <= 1024) && (blockDim.x % WARP_SIZE) == 0);
__shared__ float local_vals[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
if (lane_id == 0) {
local_vals[warp_id] = val;
}
__syncthreads();
val = 0.0f;
if (lane_id < (static_cast<int>(blockDim.x) / WARP_SIZE)) {
val = local_vals[lane_id];
}
return warp_reduce_sum(val);
} else {
return val;
}
}
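The warp_reduce_max/warp_reduce_sum calls above are helpers assumed from common.cuh; for completeness, a sketch of the shuffle-based butterfly reduction such helpers conventionally implement (assuming 32-lane warps):

    // Butterfly reduction: after log2(WARP_SIZE) xor-shuffle steps every lane
    // holds the warp-wide sum. The max variant replaces += with fmaxf.
    static __device__ __forceinline__ float warp_reduce_sum_sketch(float v) {
    #pragma unroll
        for (int offset = WARP_SIZE/2; offset > 0; offset >>= 1) {
            v += __shfl_xor_sync(0xffffffff, v, offset, WARP_SIZE);
        }
        return v;
    }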
// TODO: Template to allow keeping ncols in registers if they fit
static __device__ void soft_max_f32_parallelize_cols_single_row(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p) {
namespace cg = cooperative_groups;
const cg::grid_group g = cg::this_grid();
const int tid = threadIdx.x;
const int col_start = blockIdx.x * blockDim.x + tid;
const int n_elem_per_thread = 4;
float local_vals[n_elem_per_thread] = { -INFINITY, -INFINITY, -INFINITY, -INFINITY };
float local_max = -INFINITY;
const int step_size = gridDim.x * blockDim.x;
// Compute thread-local max
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
local_max = fmaxf(local_max, local_vals[i]);
}
col += step_size * n_elem_per_thread;
}
// Compute CTA-level max
local_max = two_stage_warp_reduce_max(local_max);
// Store CTA-level max to GMEM
if (tid == 0) {
tmp_maxs[blockIdx.x] = local_max;
}
g.sync();
// Compute global max from CTA-level maxs
assert(gridDim.x < blockDim.x); // currently we only support this case
if (tid < gridDim.x) {
local_max = tmp_maxs[tid];
} else {
local_max = -INFINITY;
}
local_max = two_stage_warp_reduce_max(local_max);
// Compute softmax dividends, accumulate divisor
float tmp_expf = 0.0f;
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? x[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
const float tmp = expf(local_vals[i] - local_max);
tmp_expf += tmp;
dst[idx] = tmp;
}
}
col += step_size * n_elem_per_thread;
}
// Reduce divisor within CTA
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
// Store CTA-level sum to GMEM
if (tid == 0) {
tmp_sums[blockIdx.x] = tmp_expf;
}
g.sync();
// Compute global sum from CTA-level sums
if (tid < gridDim.x) {
tmp_expf = tmp_sums[tid];
} else {
tmp_expf = 0.0f;
}
tmp_expf = two_stage_warp_reduce_sum(tmp_expf);
// Divide dividend by global sum + store data
for (int col = col_start; col < p.ncols;) {
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
local_vals[i] = idx < p.ncols ? dst[idx] : -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_elem_per_thread; i++) {
const int idx = col + i * step_size;
if (idx < p.ncols) {
dst[idx] = local_vals[i] / tmp_expf;
}
}
col += step_size * n_elem_per_thread;
}
}
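The three passes above implement the standard numerically stable softmax: with m = max_j x_j, softmax(x)_i = exp(x_i - m) / sum_j exp(x_j - m). Pass one computes m grid-wide, pass two writes the numerators exp(x_i - m) while accumulating the denominator, and pass three divides; subtracting m keeps every exponent at or below zero, so expf cannot overflow regardless of the input range.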
#ifdef __clang__ #ifdef __clang__
#pragma clang diagnostic pop #pragma clang diagnostic pop
#endif // __clang__ #endif // __clang__
@ -374,31 +216,9 @@ static void launch_soft_max_kernels(const float * x, const T * mask, const float
soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p); soft_max_f32<true, 0, 0><<<block_nums, block_dims, nbytes_shared, stream>>>(x, mask, sinks, dst, p);
} }
__launch_bounds__(8*WARP_SIZE, 1) static __global__ void soft_max_f32_parallelize_cols(const float * __restrict__ x,
float * __restrict__ dst,
float * __restrict__ tmp_maxs,
float * __restrict__ tmp_sums,
const soft_max_params p)
// We loop over all rows instead of parallelizing across gridDim.y as cooperative groups
// currently only support synchronizing the complete grid if not launched as a cluster group
// (which requires CC > 9.0)
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#grid-synchronization
// https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/device-callable-apis.html#class-cluster-group
{
for (int rowx = 0; rowx < p.ne01 * p.ne02 * p.ne03; rowx++) {
soft_max_f32_parallelize_cols_single_row(x + int64_t(rowx) * p.ncols, dst + int64_t(rowx) * p.ncols, tmp_maxs,
tmp_sums, p);
}
}
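The grid-wide g.sync() used by the row kernel above only works when the kernel is launched cooperatively and the whole grid is co-resident on the device; otherwise blocks that are never scheduled cannot reach the barrier and the sync deadlocks. A sketch of the host-side gating and launch (argument names are illustrative; the backend's real launch appears further below):

    // Gate on device support, then launch cooperatively with one block per SM so
    // the entire grid is guaranteed to be resident (required for grid sync).
    static bool launch_cooperative(const float * x, float * dst, float * tmp_maxs,
                                   float * tmp_sums, soft_max_params params,
                                   int device_id, int num_sms, cudaStream_t stream) {
        int supported = 0;
        cudaDeviceGetAttribute(&supported, cudaDevAttrCooperativeLaunch, device_id);
        if (!supported) {
            return false; // caller falls back to the non-cooperative kernel
        }
        void * args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs,
                          (void *) &tmp_sums, (void *) &params };
        return cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
                                           dim3(num_sms, 1, 1), dim3(8*WARP_SIZE, 1, 1),
                                           args, 0, stream) == cudaSuccess;
    }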
template <typename T> template<typename T>
static void soft_max_f32_cuda(const float * x, static void soft_max_f32_cuda(const float * x, const T * mask, const float * sinks, float * dst, const soft_max_params & params, cudaStream_t stream) {
const T * mask,
const float * sinks,
float * dst,
const soft_max_params & params,
cudaStream_t stream,
[[maybe_unused]] ggml_backend_cuda_context & ctx) {
int nth = WARP_SIZE; int nth = WARP_SIZE;
const int64_t ncols_x = params.ncols; const int64_t ncols_x = params.ncols;
@ -416,25 +236,8 @@ static void soft_max_f32_cuda(const float * x,
if (nbytes_shared <= smpbo) { if (nbytes_shared <= smpbo) {
launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared); launch_soft_max_kernels<32, 64, 128, 256, 512, 1024, 2048, 4096>(x, mask, sinks, dst, params, stream, block_dims, block_nums, nbytes_shared);
} else { } else {
// Parallelize across SMs for top-p/dist-sampling const size_t nbytes_shared_low = WARP_SIZE*sizeof(float);
// The heuristic for parallelizing rows across SMs vs parallelizing single row & looping over all rows was done on the basis of a B6000 GPU and soft_max_f32<false, 0, 0><<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
// Can be adapted further for lower-SM-count GPUs, though keeping data in registers should be implemented first as that is the optimal solution.
if (ggml_cuda_info().devices[id].supports_cooperative_launch &&
ncols_x / (params.ne01 * params.ne02 * params.ne03) > 8192 && mask == nullptr && sinks == nullptr &&
params.scale == 1.0f && params.max_bias == 0.0f) {
ggml_cuda_pool_alloc<float> tmp_maxs_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
ggml_cuda_pool_alloc<float> tmp_sums_alloc(ctx.pool(), ggml_cuda_info().devices[id].nsm * sizeof(float));
void * kernel_args[] = { (void *) &x, (void *) &dst, (void *) &tmp_maxs_alloc.ptr,
(void *) &tmp_sums_alloc.ptr, (void *) const_cast<soft_max_params *>(&params) };
CUDA_CHECK(cudaLaunchCooperativeKernel((void *) soft_max_f32_parallelize_cols,
dim3(ggml_cuda_info().devices[id].nsm, 1, 1),
dim3(WARP_SIZE * 8, 1, 1), kernel_args, 0, stream));
} else {
const size_t nbytes_shared_low = WARP_SIZE * sizeof(float);
soft_max_f32<false, 0, 0>
<<<block_nums, block_dims, nbytes_shared_low, stream>>>(x, mask, sinks, dst, params);
}
} }
} }
@ -512,9 +315,9 @@ void ggml_cuda_op_soft_max(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
params.m1 = m1; params.m1 = m1;
if (use_f16) { if (use_f16) {
soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); soft_max_f32_cuda(src0_d, (const half *) src1_d, (const float *) src2_d, dst_d, params, stream);
} else { } else {
soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream, ctx); soft_max_f32_cuda(src0_d, (const float *) src1_d, (const float *) src2_d, dst_d, params, stream);
} }
} }

View File

@ -114,7 +114,7 @@ __global__ void __launch_bounds__(splitD, 1)
#endif // __clang__ #endif // __clang__
// assumes as many threads as d_state // assumes as many threads as d_state
template <int c_factor, int d_state> template <int splitH, int d_state>
__global__ void __launch_bounds__(d_state, 1) __global__ void __launch_bounds__(d_state, 1)
ssm_scan_f32_group( ssm_scan_f32_group(
const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2, const float * __restrict__ src0, const float * __restrict__ src1, const float * __restrict__ src2,
@ -125,25 +125,20 @@ __global__ void __launch_bounds__(d_state, 1)
const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3, const int src4_nb2, const int src4_nb3, const int src5_nb2, const int src5_nb3,
const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) { const int64_t s_off, const int64_t n_head, const int64_t d_head, const int64_t n_group, const int64_t n_tok) {
const int warp = threadIdx.x / WARP_SIZE; const int head_idx = (blockIdx.x * splitH) / d_head;
const int lane = threadIdx.x % WARP_SIZE; const int head_off = ((blockIdx.x * splitH) % d_head) * sizeof(float);
const int warp_idx = blockIdx.x * c_factor + warp; const int seq_idx = blockIdx.y;
const int head_idx = warp_idx / d_head;
const int head_off = (warp_idx % d_head) * sizeof(float);
const int seq_idx = blockIdx.y;
const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float); const int group_off = (head_idx / (n_head / n_group)) * d_state * sizeof(float);
// TODO: refactor strides to be in elements/floats instead of bytes to be cleaner and consistent with the rest of the codebase const float * s0_block = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
const float * s0_warp = (const float *) ((const char *) src0 + src6[seq_idx] * src0_nb3 + head_idx * src0_nb2 + head_off * d_state); const float * x_block = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + blockIdx.x * splitH * sizeof(float));
const float * x_warp = (const float *) ((const char *) src1 + (seq_idx * src1_nb3) + (warp_idx * sizeof(float))); const float * dt_block = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float));
const float * dt_warp = (const float *) ((const char *) src2 + (seq_idx * src2_nb2) + head_idx * sizeof(float)); const float * A_block = (const float *) ((const char *) src3 + head_idx * src3_nb1);
const float * A_warp = (const float *) ((const char *) src3 + head_idx * src3_nb1); const float * B_block = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off));
const float * B_warp = (const float *) ((const char *) src4 + (seq_idx * src4_nb3) + (group_off)); const float * C_block = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off));
const float * C_warp = (const float *) ((const char *) src5 + (seq_idx * src5_nb3) + (group_off)); float * y_block = dst + (seq_idx * n_tok * n_head * d_head) + blockIdx.x * splitH;
float * y_warp = dst + (seq_idx * n_tok * n_head * d_head) + warp_idx; float * s_block = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
float * s_warp = (float *) ((char *) dst + s_off + seq_idx * src0_nb3 + head_idx * src0_nb2 + head_off * d_state);
// strides across n_seq_tokens // strides across n_seq_tokens
const int stride_x = src1_nb2 / sizeof(float); const int stride_x = src1_nb2 / sizeof(float);
@ -152,42 +147,80 @@ __global__ void __launch_bounds__(d_state, 1)
const int stride_C = src5_nb2 / sizeof(float); const int stride_C = src5_nb2 / sizeof(float);
const int stride_y = n_head * d_head; const int stride_y = n_head * d_head;
float state[c_factor]; float state[splitH];
float state_sum = 0.0f; // for the parallel accumulation
__shared__ float stateC[splitH * d_state];
#pragma unroll #pragma unroll
for (int j = 0; j < c_factor; j++) { for (int j = 0; j < splitH; j++) {
state[j] = s0_warp[WARP_SIZE * j + lane]; state[j] = s0_block[j * d_state + threadIdx.x];
} }
for (int64_t i = 0; i < n_tok; i++) { for (int64_t i = 0; i < n_tok; i++) {
// NOTE: dt_soft_plus, dA and x_dt have the same value for a warp here. // TODO: only calculate dA and dt_soft_plus once per head instead of every splitH head elements
// Recalculation is intentional; sharing via shuffles/smem proved slower due to sync overhead. // TODO: only calculate B and C once per head group
const float dt_soft_plus = (dt_warp[i * stride_dt] <= 20.0f ? log1pf(expf(dt_warp[i * stride_dt])) : dt_warp[i * stride_dt]); // NOTE: dt_soft_plus, dA and x_dt have the same value across threads here.
float dt_soft_plus = dt_block[i * stride_dt];
if (dt_soft_plus <= 20.0f) {
dt_soft_plus = log1pf(expf(dt_soft_plus));
}
const float dA = expf(dt_soft_plus * A_block[0]);
const float B = B_block[i * stride_B + threadIdx.x];
const float C = C_block[i * stride_C + threadIdx.x];
state_sum = 0.0f; // across d_head
const float dA = expf(dt_soft_plus * A_warp[0]);
const float x_dt = x_warp[i * stride_x] * dt_soft_plus;
#pragma unroll #pragma unroll
for (int j = 0; j < c_factor; j++) { for (int j = 0; j < splitH; j++) {
const float B_val = B_warp[i * stride_B + WARP_SIZE * j + lane]; const float x_dt = x_block[i * stride_x + j] * dt_soft_plus;
const float C_val = C_warp[i * stride_C + WARP_SIZE * j + lane];
state[j] = (state[j] * dA) + (B_val * x_dt); state[j] = (state[j] * dA) + (B * x_dt);
state_sum += state[j] * C_val;
stateC[j * d_state + threadIdx.x] = state[j] * C;
} }
// parallel accumulation for output __syncthreads();
state_sum = warp_reduce_sum(state_sum);
if (lane == 0) { // parallel accumulation for stateC
y_warp[i * stride_y] = state_sum; // TODO: simplify
{
static_assert((d_state & -d_state) == d_state, "the state size has to be a power of 2");
static_assert((splitH & -splitH) == splitH, "splitH has to be a power of 2");
// reduce until w matches the warp size
// TODO: does this work even when the physical warp size is 64?
#pragma unroll
for (int w = d_state; w > WARP_SIZE; w >>= 1) {
// (assuming there are d_state threads)
#pragma unroll
for (int j = 0; j < ((w >> 1) * splitH + d_state - 1) / d_state; j++) {
// TODO: check for bank conflicts
const int k = (threadIdx.x % (w >> 1)) + (d_state * (threadIdx.x / (w >> 1))) + j * d_state * (d_state / (w >> 1));
stateC[k] += stateC[k + (w >> 1)];
}
__syncthreads();
}
static_assert(splitH >= d_state / WARP_SIZE);
#pragma unroll
for (int j = 0; j < splitH / (d_state / WARP_SIZE); j++) {
float y = stateC[(threadIdx.x % WARP_SIZE) + d_state * (threadIdx.x / WARP_SIZE) + j * d_state * (d_state / WARP_SIZE)];
y = warp_reduce_sum(y);
// store the above accumulations
if (threadIdx.x % WARP_SIZE == 0) {
const int k = threadIdx.x / WARP_SIZE + j * (d_state / WARP_SIZE);
y_block[i * stride_y + k] = y;
}
}
} }
} }
// write back the state // write back the state
#pragma unroll #pragma unroll
for (int j = 0; j < c_factor; j++) { for (int j = 0; j < splitH; j++) {
s_warp[WARP_SIZE * j + lane] = state[j]; s_block[j * d_state + threadIdx.x] = state[j];
} }
} }
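
The per-warp accumulation above ends with warp_reduce_sum followed by a lane-0 store, replacing the old shared-memory tree reduction. A minimal sketch of the usual shuffle-based butterfly reduction, assuming a 32-lane warp (the helper name is illustrative):

__device__ float warp_reduce_sum_sketch(float x) {
#pragma unroll
    for (int offset = 32 / 2; offset > 0; offset >>= 1) {
        x += __shfl_xor_sync(0xffffffff, x, offset); // exchange with lane ^ offset
    }
    return x; // every lane now holds the 32-lane sum
}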
@ -198,24 +231,27 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim, const int src5_nb3, const int64_t s_off, const int64_t d_state, const int64_t head_dim,
const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq, const int64_t n_head, const int64_t n_group, const int64_t n_tok, const int64_t n_seq,
cudaStream_t stream) { cudaStream_t stream) {
const int threads = 128;
// NOTE: if you change conditions here, be sure to update the corresponding supports_op condition! // NOTE: if you change conditions here, be sure to update the corresponding supports_op condition!
if (src3_nb1 == sizeof(float)) { if (src3_nb1 == sizeof(float)) {
// Mamba-2 // Mamba-2
if (d_state == 128) { if (d_state == 128) {
constexpr int threads = 128; GGML_ASSERT(d_state % threads == 0);
constexpr int num_warps = threads/WARP_SIZE; // NOTE: can be any power of two between 4 and 64
const int splitH = 16;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); GGML_ASSERT(head_dim % splitH == 0);
ssm_scan_f32_group<128/WARP_SIZE, 128><<<blocks, threads, 0, stream>>>( const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 128><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst, src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
} else if (d_state == 256) { // Falcon-H1 } else if (d_state == 256) { // Falcon-H1
constexpr int threads = 256; const int threads = 256;
constexpr int num_warps = threads/WARP_SIZE; // NOTE: can be any power of two between 8 and 64
const int splitH = 16;
const dim3 blocks((n_head * head_dim + (num_warps - 1)) / num_warps, n_seq, 1); GGML_ASSERT(head_dim % splitH == 0);
ssm_scan_f32_group<256/WARP_SIZE, 256><<<blocks, threads, 0, stream>>>( const dim3 blocks((n_head * head_dim + (splitH - 1)) / splitH, n_seq, 1);
ssm_scan_f32_group<16, 256><<<blocks, threads, 0, stream>>>(
src0, src1, src2, src3, src4, src5, src6, dst, src0, src1, src2, src3, src4, src5, src6, dst,
src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1, src0_nb2, src0_nb3, src1_nb2, src1_nb3, src2_nb1, src2_nb2, src3_nb1,
src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok); src4_nb2, src4_nb3, src5_nb2, src5_nb3, s_off, n_head, head_dim, n_group, n_tok);
@ -224,7 +260,6 @@ static void ssm_scan_f32_cuda(const float * src0, const float * src1, const floa
} }
} else { } else {
// Mamba-1 // Mamba-1
constexpr int threads = 128;
GGML_ASSERT(n_head % threads == 0); GGML_ASSERT(n_head % threads == 0);
GGML_ASSERT(head_dim == 1); GGML_ASSERT(head_dim == 1);
GGML_ASSERT(n_group == 1); GGML_ASSERT(n_group == 1);
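
To make the Mamba-2 launch geometry above concrete: each of a block's four warps owns one (head, head-element) output, with its 32 lanes times c_factor == 4 registers covering the 128 state columns, so the grid ceil-divides n_head * head_dim by the warps per block. A host-side sketch of that computation (illustrative helper, not part of the file):

static dim3 ssm_scan_grid_sketch(int64_t n_head, int64_t head_dim, int64_t n_seq) {
    constexpr int threads   = 128;                 // == d_state for Mamba-2
    constexpr int warp_size = 32;
    constexpr int num_warps = threads / warp_size; // 4 warps, one output element each
    return dim3((n_head * head_dim + num_warps - 1) / num_warps, n_seq, 1);
}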

View File

@ -1,96 +0,0 @@
#include "argsort.cuh"
#include "top-k.cuh"
#ifdef GGML_CUDA_USE_CUB
# include <cub/cub.cuh>
# if (CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2)
# include <cuda/iterator>
# define CUB_TOP_K_AVAILABLE
using namespace cub;
# endif // CCCL_MAJOR_VERSION >= 3 && CCCL_MINOR_VERSION >= 2
#endif // GGML_CUDA_USE_CUB
#ifdef CUB_TOP_K_AVAILABLE
static void top_k_cub(ggml_cuda_pool & pool,
const float * src,
int * dst,
const int ncols,
const int k,
cudaStream_t stream) {
auto requirements = cuda::execution::require(cuda::execution::determinism::not_guaranteed,
cuda::execution::output_ordering::unsorted);
auto stream_env = cuda::stream_ref{ stream };
auto env = cuda::std::execution::env{ stream_env, requirements };
auto indexes_in = cuda::make_counting_iterator(0);
size_t temp_storage_bytes = 0;
DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst, ncols, k,
env);
ggml_cuda_pool_alloc<uint8_t> temp_storage_alloc(pool, temp_storage_bytes);
void * d_temp_storage = temp_storage_alloc.get();
DeviceTopK::MaxPairs(d_temp_storage, temp_storage_bytes, src, cuda::discard_iterator(), indexes_in, dst,
ncols, k, env);
}
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
static int next_power_of_2(int x) {
int n = 1;
while (n < x) {
n *= 2;
}
return n;
}
#endif // CUB_TOP_K_AVAILABLE
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const float * src0_d = (const float *) src0->data;
int * dst_d = (int *) dst->data;
cudaStream_t stream = ctx.stream();
// are these asserts truly necessary?
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_I32);
GGML_ASSERT(ggml_is_contiguous(src0));
const int64_t ncols = src0->ne[0];
const int64_t nrows = ggml_nrows(src0);
const int64_t k = dst->ne[0];
ggml_cuda_pool & pool = ctx.pool();
#ifdef CUB_TOP_K_AVAILABLE
// TODO: Switch to `DeviceSegmentedTopK` for multi-row TopK once implemented
// https://github.com/NVIDIA/cccl/issues/6391
// TODO: investigate if there exists a point where parallelized argsort is faster than sequential top-k
for (int i = 0; i < nrows; i++) {
top_k_cub(pool, src0_d + i * ncols, dst_d + i * k, ncols, k, stream);
}
#elif defined(GGML_CUDA_USE_CUB) // CUB_TOP_K_AVAILABLE
// Fall back to argsort + copy
const int ncols_pad = next_power_of_2(ncols);
const size_t shared_mem = ncols_pad * sizeof(int);
const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();
if (shared_mem > max_shared_mem || ncols > 1024) {
argsort_f32_i32_cuda_cub(pool, src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
} else {
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
}
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#else // GGML_CUDA_USE_CUB
ggml_cuda_pool_alloc<int> temp_dst_alloc(pool, ncols * nrows);
int * tmp_dst = temp_dst_alloc.get();
argsort_f32_i32_cuda_bitonic(src0_d, tmp_dst, ncols, nrows, GGML_SORT_ORDER_DESC, stream);
CUDA_CHECK(cudaMemcpy2DAsync(dst_d, k * sizeof(int), tmp_dst, ncols * sizeof(int), k * sizeof(int), nrows,
cudaMemcpyDeviceToDevice, stream));
#endif
}
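
Both top_k_cub above and the argsort fallback rely on CUB's two-phase calling convention: the first call with a null workspace only reports the required temporary-storage size, then the identical call is repeated with the allocated buffer. A self-contained sketch using cub::DeviceReduce::Sum (the shape is the same as for DeviceTopK::MaxPairs above; the helper name is illustrative):

#include <cub/cub.cuh>

static void cub_two_phase_sketch(const float * d_in, float * d_out, int n, cudaStream_t stream) {
    void * d_temp = nullptr;
    size_t temp_bytes = 0;
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream); // phase 1: size query only
    cudaMallocAsync(&d_temp, temp_bytes, stream);
    cub::DeviceReduce::Sum(d_temp, temp_bytes, d_in, d_out, n, stream); // phase 2: actual reduction
    cudaFreeAsync(d_temp, stream);
}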

View File

@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_op_top_k(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -45,11 +45,9 @@
#define cublasSgemm hipblasSgemm #define cublasSgemm hipblasSgemm
#define cublasStatus_t hipblasStatus_t #define cublasStatus_t hipblasStatus_t
#define cublasOperation_t hipblasOperation_t #define cublasOperation_t hipblasOperation_t
#define cudaDevAttrCooperativeLaunch hipDeviceAttributeCooperativeLaunch
#define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer #define cudaDeviceCanAccessPeer hipDeviceCanAccessPeer
#define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess
#define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess
#define cudaDeviceGetAttribute hipDeviceGetAttribute
#define cudaDeviceProp hipDeviceProp_t #define cudaDeviceProp hipDeviceProp_t
#define cudaDeviceSynchronize hipDeviceSynchronize #define cudaDeviceSynchronize hipDeviceSynchronize
#define cudaError_t hipError_t #define cudaError_t hipError_t
@ -72,7 +70,6 @@
#define cudaHostRegisterPortable hipHostRegisterPortable #define cudaHostRegisterPortable hipHostRegisterPortable
#define cudaHostRegisterReadOnly hipHostRegisterReadOnly #define cudaHostRegisterReadOnly hipHostRegisterReadOnly
#define cudaHostUnregister hipHostUnregister #define cudaHostUnregister hipHostUnregister
#define cudaLaunchCooperativeKernel hipLaunchCooperativeKernel
#define cudaLaunchHostFunc hipLaunchHostFunc #define cudaLaunchHostFunc hipLaunchHostFunc
#define cudaMalloc hipMalloc #define cudaMalloc hipMalloc
#define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault) #define cudaMallocHost(ptr, size) hipHostMalloc(ptr, size, hipHostMallocDefault)
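
The shim works by textual aliasing: ggml's CUDA sources are written against the cuda* names, and this vendor header remaps each one to its HIP equivalent at compile time. A reduced illustration with only a few of the defines and a toy caller (GGML_USE_HIP and copy_sketch are assumptions for the sketch):

#if defined(GGML_USE_HIP)
#    include <hip/hip_runtime.h>
#    define cudaStream_t             hipStream_t
#    define cudaMemcpyAsync          hipMemcpyAsync
#    define cudaMemcpyDeviceToDevice hipMemcpyDeviceToDevice
#else
#    include <cuda_runtime.h>
#endif

static void copy_sketch(void * dst, const void * src, size_t n, cudaStream_t s) {
    // resolves to a CUDA or HIP runtime call depending on the defines above
    cudaMemcpyAsync(dst, src, n, cudaMemcpyDeviceToDevice, s);
}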

View File

@ -61,7 +61,6 @@
#define cudaHostRegisterPortable musaHostRegisterPortable #define cudaHostRegisterPortable musaHostRegisterPortable
#define cudaHostRegisterReadOnly musaHostRegisterReadOnly #define cudaHostRegisterReadOnly musaHostRegisterReadOnly
#define cudaHostUnregister musaHostUnregister #define cudaHostUnregister musaHostUnregister
#define cudaLaunchCooperativeKernel musaLaunchCooperativeKernel
#define cudaLaunchHostFunc musaLaunchHostFunc #define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaMalloc musaMalloc #define cudaMalloc musaMalloc
#define cudaMallocHost musaMallocHost #define cudaMallocHost musaMallocHost

View File

@ -1773,37 +1773,6 @@ static bool hex_supported_dims2(const struct ggml_tensor * x, const struct ggml_
return true; return true;
} }
static bool ggml_hexagon_supported_flash_attn_ext(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0];
const struct ggml_tensor * src1 = op->src[1];
const struct ggml_tensor * src2 = op->src[2];
const struct ggml_tensor * src3 = op->src[3];
const struct ggml_tensor * src4 = op->src[4];
const struct ggml_tensor * dst = op;
// Check for F16 support only as requested
if ((src0->type != GGML_TYPE_F16 && src0->type != GGML_TYPE_F32) || src1->type != GGML_TYPE_F16 || src2->type != GGML_TYPE_F16) {
return false;
}
if (src3 && src3->type != GGML_TYPE_F16) { // mask
return false;
}
if (src4 && src4->type != GGML_TYPE_F32) { // sinks
return false;
}
// For now we support F32 or F16 output as htp backend often converts output on the fly if needed,
// but the op implementation writes to F16 or F32.
// Let's assume dst can be F32 or F16.
if (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) {
return false;
}
return opt_experimental;
}
static bool hex_supported_src0_type(ggml_type t) { static bool hex_supported_src0_type(ggml_type t) {
return t == GGML_TYPE_F32; return t == GGML_TYPE_F32;
} }
@ -1846,11 +1815,12 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
const struct ggml_tensor * src0 = dst->src[0]; const struct ggml_tensor * src0 = dst->src[0];
const struct ggml_tensor * src1 = dst->src[1]; const struct ggml_tensor * src1 = dst->src[1];
if (dst->type != GGML_TYPE_F32) { if (src1->type != GGML_TYPE_F32 || dst->type != GGML_TYPE_F32) {
return false; return false;
} }
if (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) { // TODO: add support for non-cont tensors
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false; return false;
} }
@ -1866,6 +1836,7 @@ static bool ggml_hexagon_supported_mul_mat(const struct ggml_hexagon_session * s
return false; // typically the lm-head which would be too large for VTCM return false; // typically the lm-head which would be too large for VTCM
} }
// if ((src0->ne[2] != src1->ne[2] || src0->ne[3] != src1->ne[3])) return false;
if ((src1->ne[2] != 1 || src1->ne[3] != 1)) { if ((src1->ne[2] != 1 || src1->ne[3] != 1)) {
return false; return false;
} }
@ -1914,10 +1885,21 @@ static bool ggml_hexagon_supported_mul_mat_id(const struct ggml_hexagon_session
} }
break; break;
case GGML_TYPE_F16:
if (!opt_experimental) {
return false;
}
break;
default: default:
return false; return false;
} }
// TODO: add support for non-cont tensors
if (!ggml_is_contiguous(src1) || !ggml_is_contiguous(dst)) {
return false;
}
return true; return true;
} }
@ -2078,46 +2060,6 @@ static bool ggml_hexagon_supported_softmax(const struct ggml_hexagon_session * s
return true; return true;
} }
static bool ggml_hexagon_supported_set_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * src1 = op->src[1]; // indices
const struct ggml_tensor * dst = op;
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
return false;
}
if (dst->type != GGML_TYPE_F16) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_get_rows(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const struct ggml_tensor * src0 = op->src[0]; // values
const struct ggml_tensor * src1 = op->src[1]; // indices
const struct ggml_tensor * dst = op;
if (src0->type != GGML_TYPE_F32) {
return false;
}
if (src1->type != GGML_TYPE_I32 && src1->type != GGML_TYPE_I64) {
return false;
}
if (dst->type != GGML_TYPE_F32) {
return false;
}
return true;
}
static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) { static bool ggml_hexagon_supported_rope(const struct ggml_hexagon_session * sess, const struct ggml_tensor * op) {
const int32_t * op_params = &op->op_params[0]; const int32_t * op_params = &op->op_params[0];
@ -2212,11 +2154,6 @@ static size_t htp_req_buff_init(htp_tensor *h, dspqueue_buffer * d, const ggml_t
d->offset = (uint8_t *) t->data - buf->base; d->offset = (uint8_t *) t->data - buf->base;
d->size = ggml_nbytes(t); d->size = ggml_nbytes(t);
if (!d->size) {
// Some requests contain srcs where ggml_nbytes() returns 0 but the rest of the op is non-empty
d->size = 64;
}
switch (type) { switch (type) {
case DSPQBUF_TYPE_DSP_WRITE_CPU_READ: case DSPQBUF_TYPE_DSP_WRITE_CPU_READ:
// Flush CPU // Flush CPU
@ -2302,17 +2239,6 @@ static inline size_t init_binary_req(htp_general_req * req, dspqueue_buffer * bu
return n_bufs; return n_bufs;
} }
static inline size_t init_get_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_GET_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
template <bool _is_src0_constant> template <bool _is_src0_constant>
static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
switch (t->op) { switch (t->op) {
@ -2340,17 +2266,6 @@ static inline size_t init_binary_id_req(htp_general_req * req, dspqueue_buffer *
return n_bufs; return n_bufs;
} }
static inline size_t init_set_rows_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
req->op = HTP_OP_SET_ROWS;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) { static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params)); memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
@ -2362,11 +2277,6 @@ static inline size_t init_unary_req(htp_general_req * req, dspqueue_buffer * buf
supported = true; supported = true;
break; break;
case GGML_OP_SCALE:
req->op = HTP_OP_SCALE;
supported = true;
break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) { if (ggml_get_unary_op(t) == GGML_UNARY_OP_SILU) {
req->op = HTP_OP_UNARY_SILU; req->op = HTP_OP_UNARY_SILU;
@ -2421,21 +2331,6 @@ static inline size_t init_rope_req(htp_general_req * req, dspqueue_buffer * bufs
return n_bufs; return n_bufs;
} }
static inline size_t init_flash_attn_ext_req(htp_general_req * req, dspqueue_buffer * bufs, const ggml_tensor * t) {
memcpy(&req->op_params, &t->op_params, sizeof(t->op_params));
req->op = HTP_OP_FLASH_ATTN_EXT;
size_t n_bufs = 0;
n_bufs += htp_req_buff_init(&req->src0, &bufs[n_bufs], t->src[0], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src1, &bufs[n_bufs], t->src[1], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src2, &bufs[n_bufs], t->src[2], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src3, &bufs[n_bufs], t->src[3], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->src4, &bufs[n_bufs], t->src[4], DSPQBUF_TYPE_CPU_WRITE_DSP_READ);
n_bufs += htp_req_buff_init(&req->dst, &bufs[n_bufs], t, DSPQBUF_TYPE_DSP_WRITE_CPU_READ);
return n_bufs;
}
static const char * ggml_backend_hexagon_name(ggml_backend_t backend) { static const char * ggml_backend_hexagon_name(ggml_backend_t backend) {
auto sess = static_cast<ggml_hexagon_session *>(backend->context); auto sess = static_cast<ggml_hexagon_session *>(backend->context);
return sess->name.c_str(); return sess->name.c_str();
@ -2522,7 +2417,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags); ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
break; break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags); ggml_hexagon_dispatch_op<init_unary_req>(sess, node, flags);
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
@ -2545,18 +2439,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags); ggml_hexagon_dispatch_op<init_rope_req>(sess, node, flags);
break; break;
case GGML_OP_FLASH_ATTN_EXT:
ggml_hexagon_dispatch_op<init_flash_attn_ext_req>(sess, node, flags);
break;
case GGML_OP_SET_ROWS:
ggml_hexagon_dispatch_op<init_set_rows_req>(sess, node, flags);
break;
case GGML_OP_GET_ROWS:
ggml_hexagon_dispatch_op<init_get_rows_req>(sess, node, flags);
break;
default: default:
GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node)); GGML_ABORT("\nggml-hex: graph-compute %s is not supported\n", ggml_op_desc(node));
} }
@ -2896,7 +2778,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
break; break;
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_SCALE:
supp = ggml_hexagon_supported_unary(sess, op); supp = ggml_hexagon_supported_unary(sess, op);
break; break;
@ -2924,18 +2805,6 @@ static bool ggml_backend_hexagon_device_supports_op(ggml_backend_dev_t dev, cons
supp = ggml_hexagon_supported_rope(sess, op); supp = ggml_hexagon_supported_rope(sess, op);
break; break;
case GGML_OP_FLASH_ATTN_EXT:
supp = ggml_hexagon_supported_flash_attn_ext(sess, op);
break;
case GGML_OP_SET_ROWS:
supp = ggml_hexagon_supported_set_rows(sess, op);
break;
case GGML_OP_GET_ROWS:
supp = ggml_hexagon_supported_get_rows(sess, op);
break;
default: default:
break; break;
} }

View File

@ -28,9 +28,6 @@ add_library(${HTP_LIB} SHARED
softmax-ops.c softmax-ops.c
act-ops.c act-ops.c
rope-ops.c rope-ops.c
flash-attn-ops.c
set-rows-ops.c
get-rows-ops.c
) )
target_compile_definitions(${HTP_LIB} PRIVATE target_compile_definitions(${HTP_LIB} PRIVATE

View File

@ -85,16 +85,13 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad, struct htp_spad * dst_spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t src0_nrows_per_thread, uint32_t src0_nrows_per_thread) {
dma_queue * dma_queue) {
htp_act_preamble3; htp_act_preamble3;
size_t src0_row_size = nb01; size_t src0_row_size = nb01;
size_t src1_row_size = nb11; size_t src1_row_size = nb11;
size_t dst_row_size = nb1; size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
const uint32_t src0_start_row = src0_nrows_per_thread * ith; const uint32_t src0_start_row = src0_nrows_per_thread * ith;
@ -108,6 +105,12 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
uint64_t t1, t2; uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count(); t1 = HAP_perf_get_qtimer_count();
int is_aligned = 1;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
is_aligned = 0;
FARF(HIGH, "swiglu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data; uint8_t * restrict data_dst = (uint8_t *) dst->data;
@ -124,81 +127,37 @@ static void glu_swiglu_fp32_per_thread(const struct htp_tensor * src0,
data_src1 += swapped ? 0 : nc_in_bytes; data_src1 += swapped ? 0 : nc_in_bytes;
} }
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); const bool opt_path = ((1 == is_aligned) && !(nb01 & (VLEN - 1)));
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 if (ir + 1 < src0_end_row) {
size_t src0_spad_half_size = src0_spad->size_per_thread / 2; htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR,
"swiglu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue,
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
//swiglu(x) = x1 * sigmoid(x0)
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
(const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
} }
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, if (opt_path) {
dst_row_size_aligned, block_size); hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, nc);
hvx_mul_mul_f32_opt((const uint8_t *) src0, (const uint8_t *) src0_spad_data, (const uint8_t *) src1,
(uint8_t *) dst, nc);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, nc, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, src1_spad_data, nc);
hvx_inverse_f32(src1_spad_data, src0_spad_data, nc);
// prefetch N+2 loop iteration if any hvx_mul_f32((const uint8_t *) src0, src0_spad_data, dst_spad_data, nc);
const uint32_t pref_block = (ir + BLOCK * 2); hvx_mul_f32(dst_spad_data, (const uint8_t *) src1, (uint8_t *) dst, nc);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
} }
} }
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count(); t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, FARF(HIGH, "swiglu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path,
ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3, ne00, ne01, ne02, ne03, src0_start_row, src0_end_row, ne10, ne11, ne12, ne13, ne0, ne1, ne2, ne3,
(unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
} }
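
The rewritten loop above is a classic double-buffered (ping-pong) DMA pipeline: each thread's VTCM scratch is split into two halves, both are primed up front, and each iteration pops a finished block, processes it, pushes the result back, and prefetches the block two iterations ahead. A hedged stand-alone sketch; dma_pull, dma_push, and dma_wait are hypothetical stand-ins for the htp-dma queue calls, kept only to make the issue order explicit:

enum { N_BLOCKS = 8, BLOCK_BYTES = 4096 };
static unsigned char vtcm_half[2][BLOCK_BYTES]; // two ping-pong halves per thread

static void dma_pull(void * dst, int block) { (void) dst; (void) block; /* queue DDR -> VTCM */ }
static void dma_push(int block, void * src) { (void) block; (void) src; /* queue VTCM -> DDR */ }
static void dma_wait(void) { /* pop the oldest completed descriptor */ }
static void process(void * buf) { (void) buf; /* HVX work on one resident block */ }

static void pipeline_sketch(void) {
    for (int b = 0; b < 2 && b < N_BLOCKS; ++b) {
        dma_pull(vtcm_half[b % 2], b); // prime both halves
    }
    for (int b = 0; b < N_BLOCKS; ++b) {
        dma_wait();                            // block b is now resident in VTCM
        process(vtcm_half[b % 2]);
        dma_push(b, vtcm_half[b % 2]);         // stream the result back to DDR
        if (b + 2 < N_BLOCKS) {
            dma_pull(vtcm_half[b % 2], b + 2); // prefetch iteration N+2 into the freed half
        }
    }
}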
@ -212,16 +171,15 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad, struct htp_spad * dst_spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t src0_nrows_per_thread, uint32_t src0_nrows_per_thread) {
dma_queue * dma_queue) {
htp_act_preamble3; htp_act_preamble3;
uint64_t t1, t2; uint64_t t1, t2;
t1 = HAP_perf_get_qtimer_count(); t1 = HAP_perf_get_qtimer_count();
size_t src0_row_size = nb01; const size_t src0_row_size = nb01;
size_t src1_row_size = nb11; const size_t src1_row_size = nb11;
size_t dst_row_size = nb1; const size_t dst_row_size = nb1;
const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows const uint32_t src0_nrows = ne01 * ne02 * ne03; // src0 rows
@ -233,110 +191,66 @@ static void glu_swiglu_oai_fp32_per_thread(const struct htp_tensor * src0,
return; return;
} }
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
FARF(HIGH, "act-f32: unaligned addresses in activations op, possibly slower execution\n");
}
const uint8_t * restrict data_src0 = (const uint8_t *) src0->data; const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
const uint8_t * restrict data_src1 = (const uint8_t *) src1->data; const uint8_t * restrict data_src1 = (const uint8_t *) src1->data;
uint8_t * restrict data_dst = (uint8_t *) dst->data; uint8_t * restrict data_dst = (uint8_t *) dst->data;
const bool src1_valid = src1->ne[0]; bool src1_valid = src1->ne[0];
const int nc = (src1_valid) ? ne00 : ne00 / 2;
if (!src1_valid) { if (!src1_valid) {
const int32_t swapped = op_params[1]; data_src1 = data_src0;
data_src1 = data_src0;
src1_row_size = src0_row_size;
const size_t nc_in_bytes = nc * SIZEOF_FP32;
data_src0 += swapped ? nc_in_bytes : 0;
data_src1 += swapped ? 0 : nc_in_bytes;
} }
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN); uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
const size_t src1_row_size_aligned = htp_round_up(src1_row_size, VLEN); uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_row_size);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN); uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); const int32_t swapped = op_params[1];
uint8_t * restrict src1_spad_data = src1_spad->data + (ith * src1_spad->size_per_thread); const float alpha = ((const float *) (op_params))[2];
uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); const float limit = ((const float *) (op_params))[3];
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 const int nc = (src1_valid) ? ne00 : ne00 / 2;
size_t src0_spad_half_size = src0_spad->size_per_thread / 2;
size_t src1_spad_half_size = src1_spad->size_per_thread / 2;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
if (BLOCK == 0) { const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
FARF(ERROR, const float * restrict src1 = (float *) (data_src1 + (ir * src1_row_size));
"swiglu-oai-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least " float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
"%zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
}
const float alpha = ((const float *) (op_params))[2];
const float limit = ((const float *) (op_params))[3];
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 if (ir + 1 < src0_end_row) {
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...)
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)),
src0_row_size_aligned, src0_row_size, block_size);
dma_queue_push_ddr_to_vtcm(
dma_queue,
dma_make_ptr(src1_spad_data + (spad_idx * src1_spad_half_size), data_src1 + (ir * src1_row_size)),
src1_row_size_aligned, src1_row_size, block_size);
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
float * dst_spad = (float *) dma_queue_pop(dma_queue).src;
float * src0_spad = (float *) dma_queue_pop(dma_queue).dst;
float * src1_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float * src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
const float * src1_spad_ptr = src1_spad + ib * (src1_row_size_aligned / sizeof(float));
float * dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((const uint8_t *) src0_spad_ptr, limit, (uint8_t *) src0_spad_ptr, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((const uint8_t *) src1_spad_ptr, -limit, limit, (uint8_t *) src1_spad_ptr, nc);
// y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32((const uint8_t *) src1_spad_ptr, 1.0, (uint8_t *) src1_spad_ptr, nc);
// x1 (dst_spad_data) = alpha * (x)
hvx_mul_scalar_f32((const uint8_t *) src0_spad_ptr, alpha, (uint8_t *) dst_spad_ptr, nc);
// x2 (dst_spad_data) = sigmoid(x1) = 1/(1+exp(-x1))
hvx_fast_sigmoid_f32((const uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
// out = x * sigmoid(alpha * x) * (y + 1.f)
hvx_mul_mul_f32_opt((const uint8_t *) src0_spad_ptr, (const uint8_t *) dst_spad_ptr,
(const uint8_t *) src1_spad_ptr, (uint8_t *) dst_spad_ptr, nc);
} }
dma_queue_push_vtcm_to_ddr(dma_queue, dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), dst_row_size, if (!src1) {
dst_row_size_aligned, block_size); src0 += swapped ? nc : 0;
src1 += swapped ? 0 : nc;
// prefetch N+2 loop iteration if any
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
dma_queue_push_ddr_to_vtcm(dma_queue, dma_make_ptr(src1_spad, data_src1 + (pref_block * src1_row_size)),
src1_row_size_aligned, src1_row_size, pref_block_size);
} }
}
dma_queue_flush(dma_queue); // x (src0_spad_data) = std::min(src0_p[k], limit);
hvx_min_scalar_f32((const uint8_t *) src0, limit, src0_spad_data, nc);
// y1 (src1_spad_data) = std::clamp(src1_p[k], -limit, limit);
hvx_clamp_scalar_f32((const uint8_t *) src1, -limit, limit, src1_spad_data, nc);
// y (src1_spad_data) = y1 + 1.f
hvx_add_scalar_f32(src1_spad_data, 1.0, src1_spad_data, nc);
// x1 (dst_spad_data) = alpha * (x)
hvx_mul_scalar_f32(src0_spad_data, alpha, dst_spad_data, nc);
// x2 (dst_spad_data) = expf(-x1)
hvx_exp_f32(dst_spad_data, dst_spad_data, nc, true);
// x3 (dst_spad_data) = x2 + 1.f
hvx_add_scalar_f32(dst_spad_data, 1.0, dst_spad_data, nc);
// x4 (dst_spad_data) = 1 / x3
hvx_inverse_f32(dst_spad_data, dst_spad_data, nc);
// out_glu(dst_spad_data) = x * x4
hvx_mul_f32(src0_spad_data, dst_spad_data, dst_spad_data, nc);
// out = out_glu * (y + 1.f);
hvx_mul_f32(dst_spad_data, src1_spad_data, (uint8_t *) dst, nc);
}
t2 = HAP_perf_get_qtimer_count(); t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "swiglu-oai-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0], FARF(HIGH, "swiglu-f32 %d/%d: %ux%ux%ux%u (%u:%u) x %ux%ux%ux%u -> %ux%ux%ux%u usec %u\n", ith, nth, src0->ne[0],
src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2], src0->ne[1], src0->ne[2], src0->ne[3], src0_start_row, src0_end_row, src1->ne[0], src1->ne[1], src1->ne[2],
src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); src1->ne[3], dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
} }
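
For reference, the element-wise math that the HVX call sequence above implements is out = min(x, limit) * sigmoid(alpha * min(x, limit)) * (clamp(y, -limit, limit) + 1). A scalar reference, useful for checking the vector path:

#include <math.h>

static float swiglu_oai_ref(float x, float y, float alpha, float limit) {
    x = fminf(x, limit);                                // x = min(x, limit)
    y = fmaxf(-limit, fminf(y, limit));                 // y = clamp(y, -limit, limit)
    const float sig = 1.0f / (1.0f + expf(-alpha * x)); // sigmoid(alpha * x)
    return x * sig * (y + 1.0f);
}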
@ -457,8 +371,7 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
struct htp_spad * dst_spad, struct htp_spad * dst_spad,
uint32_t nth, uint32_t nth,
uint32_t ith, uint32_t ith,
uint32_t src0_nrows_per_thread, uint32_t src0_nrows_per_thread) {
dma_queue * dma_queue) {
htp_act_preamble2; htp_act_preamble2;
uint64_t t1, t2; uint64_t t1, t2;
@ -466,8 +379,6 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
const size_t src0_row_size = nb01; const size_t src0_row_size = nb01;
const size_t dst_row_size = nb1; const size_t dst_row_size = nb1;
const size_t src0_row_size_aligned = htp_round_up(src0_row_size, VLEN);
const size_t dst_row_size_aligned = htp_round_up(dst_row_size, VLEN);
const uint32_t src0_nrows = ne01 * ne02 * ne03; const uint32_t src0_nrows = ne01 * ne02 * ne03;
@ -479,91 +390,64 @@ static void unary_silu_fp32_per_thread(const struct htp_tensor * src0,
return; return;
} }
const uint8_t * data_src0 = (const uint8_t *) src0->data; int is_aligned = 1;
uint8_t * data_dst = (uint8_t *) dst->data; int opt_path = 0;
if (!htp_is_aligned((void *) src0->data, VLEN) || !htp_is_aligned((void *) dst->data, VLEN)) {
uint8_t * src0_spad_data = src0_spad->data + (ith * src0_spad->size_per_thread); is_aligned = 0;
uint8_t * dst_spad_data = dst_spad->data + (ith * dst_spad->size_per_thread); FARF(HIGH, "silu-f32: unaligned addresses in elementwise op, possibly slower execution\n");
}
// While given src0_spad->size_per_thread, divide it to two ping-pong buffer for src0 if ((1 == is_aligned) && !(nb01 & (VLEN - 1))) {
size_t src0_spad_half_size = src0_spad->size_per_thread / 2; opt_path = 1;
size_t dst_spad_half_size = dst_spad->size_per_thread / 2;
const int BLOCK = src0_spad_half_size / src0_row_size_aligned; // How many rows can we process in one block
if (BLOCK == 0) {
FARF(ERROR, "silu-f32 : current VTCM reservation %zu is too small for even 1 row per thread, needed at least %zu\n",
src0_spad->size_per_thread, src0_row_size_aligned);
return;
} }
// See discussion: https://github.com/ggml-org/llama.cpp/pull/18151#issuecomment-3678235379 const uint8_t * restrict data_src0 = (const uint8_t *) src0->data;
for (uint32_t ir = src0_start_row, spad_idx = 0; ir < src0_end_row && spad_idx < 2; ir += BLOCK, spad_idx++) { uint8_t * restrict data_dst = (uint8_t *) dst->data;
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir);
// Dummy DMA transaction for sequencing (interleaving dst,src,dst,...) uint8_t * restrict src0_spad_data = src0_spad->data + (ith * src0_row_size);
dma_queue_push_vtcm_to_ddr(dma_queue, uint8_t * restrict dst_spad_data = dst_spad->data + (ith * dst_row_size);
dma_make_ptr(data_dst, dst_spad_data + (spad_idx * dst_spad_half_size)),
dst_row_size, dst_row_size_aligned, 0);
dma_queue_push_ddr_to_vtcm(dma_queue, for (uint32_t ir = src0_start_row; ir < src0_end_row; ir++) {
dma_make_ptr(src0_spad_data + (spad_idx * src0_spad_half_size), data_src0 + (ir * src0_row_size)), const float * restrict src0 = (float *) (data_src0 + (ir * src0_row_size));
src0_row_size_aligned, src0_row_size, block_size); float * restrict dst = (float *) (data_dst + (ir * dst_row_size));
}
for (uint32_t ir = src0_start_row; ir < src0_end_row; ir += BLOCK) { if (ir + 1 < src0_end_row) {
const uint32_t block_size = MIN(BLOCK, src0_end_row - ir); htp_l2fetch(src0 + src0_row_size, 1, src0_row_size, src0_row_size);
float* dst_spad = (float *) dma_queue_pop(dma_queue).src;
float* src0_spad = (float *) dma_queue_pop(dma_queue).dst;
for (uint32_t ib = 0; ib < block_size; ib++) {
const float* src0_spad_ptr = src0_spad + ib * (src0_row_size_aligned / sizeof(float));
float* dst_spad_ptr = dst_spad + ib * (dst_row_size_aligned / sizeof(float));
// silu = x * sigmoid(x)
hvx_fast_sigmoid_f32((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
hvx_mul_f32_opt((const uint8_t *) src0_spad_ptr, (uint8_t *) dst_spad_ptr, (uint8_t *) dst_spad_ptr, ne0);
} }
dma_queue_push_vtcm_to_ddr(dma_queue, if (1 == opt_path) {
dma_make_ptr(data_dst + (ir * dst_row_size), dst_spad), hvx_fast_sigmoid_f32((const uint8_t *) src0, (uint8_t *) src0_spad_data, ne0);
dst_row_size, dst_row_size_aligned, block_size); hvx_mul_f32_opt((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
} else {
hvx_exp_f32((const uint8_t *) src0, src0_spad_data, ne0, true);
hvx_add_scalar_f32(src0_spad_data, 1.0, dst_spad_data, ne0);
hvx_inverse_f32(dst_spad_data, src0_spad_data, ne0);
// prefetch N+2 loop iteration if any hvx_mul_f32((const uint8_t *) src0, src0_spad_data, (uint8_t *) dst, ne0);
const uint32_t pref_block = (ir + BLOCK * 2);
if (pref_block < src0_end_row) {
const uint32_t pref_block_size = MIN(BLOCK, src0_end_row - pref_block);
dma_queue_push_ddr_to_vtcm(dma_queue,
dma_make_ptr(src0_spad, data_src0 + (pref_block * src0_row_size)),
src0_row_size_aligned, src0_row_size, pref_block_size);
} }
} }
dma_queue_flush(dma_queue);
t2 = HAP_perf_get_qtimer_count(); t2 = HAP_perf_get_qtimer_count();
FARF(HIGH, "silu-f32 %d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, ne00, ne01, ne02, FARF(HIGH, "silu-f32 %d/%d/%d: %ux%ux%ux%u (%u:%u) -> %ux%ux%ux%u usec %u\n", ith, nth, opt_path, ne00, ne01, ne02,
ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1)); ne03, src0_start_row, src0_end_row, ne0, ne1, ne2, ne3, (unsigned) HAP_perf_qtimer_count_to_us(t2 - t1));
} }
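// Scalar reference for what both the optimized and fallback paths above
// compute per element (assumes <math.h>): silu(x) = x * sigmoid(x).
// The fallback path spells this out as exp, +1, reciprocal, multiply.
static float silu_ref(float x) {
    return x / (1.0f + expf(-x));
}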
static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) { static void unary_silu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i, unary_silu_fp32_per_thread(&octx->src0, &octx->dst, octx->op_params, &octx->src0_spad, &octx->dst_spad, n, i,
octx->src0_nrows_per_thread, octx->ctx->dma[i]); octx->src0_nrows_per_thread);
} }
static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) { static void glu_swiglu_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, glu_swiglu_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
} }
static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) { static void glu_swiglu_oai_fp32(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = (struct htp_ops_context *) data; struct htp_ops_context * octx = (struct htp_ops_context *) data;
glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad, glu_swiglu_oai_fp32_per_thread(&octx->src0, &octx->src1, &octx->dst, octx->op_params, &octx->src0_spad,
&octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread, octx->ctx->dma[i]); &octx->src1_spad, &octx->dst_spad, n, i, octx->src0_nrows_per_thread);
} }
static int execute_op_activations_fp32(struct htp_ops_context * octx) { static int execute_op_activations_fp32(struct htp_ops_context * octx) {

View File

@ -1,566 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-dma.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
// Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16)
HVX_Vector x_hf = vx[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NaNs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x_hf = Q6_V_vand_QV(bmask, x_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
hvx_vec_store_u(r, 4, rsum);
}
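// Scalar reference for hvx_dot_f32_f16_aa above: *r = s * sum_i y[i] * x[i],
// with y in fp32 and x in fp16. Note the vector path first rounds y to fp16
// and multiplies half by half, so low-order bits can differ from this.
static inline void dot_f32_f16_ref(float * r, const float * y, const __fp16 * x,
                                   unsigned int n, float s) {
    float acc = 0.0f;
    for (unsigned int i = 0; i < n; i++) {
        acc += y[i] * (float) x[i];
    }
    *r = acc * s;
}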
// Dot product of two F16 vectors, accumulating to float
static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict x, const void * restrict y, unsigned int n, float s) {
const HVX_Vector * restrict vx = (const HVX_Vector * restrict) x; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x_hf = Q6_V_vand_QV(bmask, vx[i]);
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
}
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_fp32(s));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_qf32_reduce_sum(rsum));
hvx_vec_store_u(r, 4, rsum);
}
// MAD: y (F32) += x (F16) * s (float)
static inline void hvx_mad_f32_f16_aa(float * restrict y, const void * restrict x, int n, float s) {
const HVX_Vector * restrict ptr_x = (const HVX_Vector *) x;
HVX_Vector * restrict ptr_y = (HVX_Vector *) y;
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
HVX_Vector S = hvx_vec_splat_fp16(s);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
// Multiply x * s -> pair of F32 vectors
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
ptr_y[i*2] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_lo_W(xs_p), ptr_y[i*2]));
ptr_y[i*2+1] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_V_hi_W(xs_p), ptr_y[i*2+1]));
}
if (nloe) {
HVX_VectorPair xs_p = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(ptr_x[i]), S);
HVX_Vector xs = Q6_V_lo_W(xs_p);
i = 2 * i; // index for ptr_y
if (nloe >= 32) {
ptr_y[i] = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
nloe -= 32; ++i; xs = Q6_V_hi_W(xs_p);
}
if (nloe) {
HVX_Vector xy = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(xs, ptr_y[i]));
hvx_vec_store_u(&ptr_y[i], nloe * 4, xy);
}
}
}
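// A scalar sketch of the online-softmax update used inside the block loop of
// flash_attn_ext_f16_thread below (M: running max, S: running denominator,
// acc: the VKQ32 accumulator row of dv floats); kept here for reference only,
// with v_rows as a plain contiguous fp32 stand-in for the padded fp16 V block.
static void online_softmax_step_ref(float * M, float * S, float * acc, uint32_t dv,
                                    const float * scores, const float * v_rows, uint32_t n) {
    float m_new = *M;
    for (uint32_t i = 0; i < n; i++) {
        m_new = fmaxf(m_new, scores[i]); // new running maximum
    }
    const float ms = expf(*M - m_new);   // rescale factor for the old state
    for (uint32_t d = 0; d < dv; d++) {
        acc[d] *= ms;
    }
    *S *= ms;
    for (uint32_t i = 0; i < n; i++) {
        const float p = expf(scores[i] - m_new);
        *S += p;
        for (uint32_t d = 0; d < dv; d++) {
            acc[d] += p * v_rows[i * dv + d]; // accumulate p * V row
        }
    }
    *M = m_new;
}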
#define FLASH_ATTN_BLOCK_SIZE 128
static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, int nth) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.data) ? &octx->src3 : NULL;
const struct htp_tensor * sinks = (octx->src4.data) ? &octx->src4 : NULL;
struct htp_tensor * dst = &octx->dst;
const uint32_t neq0 = q->ne[0];
const uint32_t neq1 = q->ne[1];
const uint32_t neq2 = q->ne[2];
const uint32_t neq3 = q->ne[3];
const uint32_t nek0 = k->ne[0];
const uint32_t nek1 = k->ne[1];
const uint32_t nek2 = k->ne[2];
const uint32_t nek3 = k->ne[3];
const uint32_t nev0 = v->ne[0];
const uint32_t nev1 = v->ne[1];
const uint32_t nev2 = v->ne[2];
const uint32_t nev3 = v->ne[3];
const uint32_t nbq1 = q->nb[1];
const uint32_t nbq2 = q->nb[2];
const uint32_t nbq3 = q->nb[3];
const uint32_t nbk1 = k->nb[1];
const uint32_t nbk2 = k->nb[2];
const uint32_t nbk3 = k->nb[3];
const uint32_t nbv1 = v->nb[1];
const uint32_t nbv2 = v->nb[2];
const uint32_t nbv3 = v->nb[3];
const uint32_t ne1 = dst->ne[1];
const uint32_t ne2 = dst->ne[2];
const uint32_t ne3 = dst->ne[3];
const uint32_t nb1 = dst->nb[1];
const uint32_t nb2 = dst->nb[2];
const uint32_t nb3 = dst->nb[3];
float scale = 1.0f;
float max_bias = 0.0f;
float logit_softcap = 0.0f;
memcpy(&scale, (float *) octx->op_params + 0, sizeof(float));
memcpy(&max_bias, (float *) octx->op_params + 1, sizeof(float));
memcpy(&logit_softcap, (float *) octx->op_params + 2, sizeof(float));
if (logit_softcap != 0) {
scale /= logit_softcap;
}
// total rows in q
const uint32_t nr = neq1*neq2*neq3;
const uint32_t dr = (nr + nth - 1) / nth;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = MIN(ir0 + dr, nr);
if (ir0 >= ir1) return;
dma_queue * dma = octx->ctx->dma[ith];
const uint32_t DK = nek0;
const uint32_t DV = nev0;
const size_t size_q_row = DK * ((q->type == HTP_TYPE_F32) ? 4 : 2);
const size_t size_q_row_padded = htp_round_up(size_q_row, 128);
const size_t size_k_row = DK * sizeof(__fp16);
const size_t size_v_row = DV * sizeof(__fp16);
const size_t size_m_row = FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16); // Treat block as one row for mask
const size_t size_k_row_padded = htp_round_up(size_k_row, 128);
const size_t size_v_row_padded = htp_round_up(size_v_row, 128);
const size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
const size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
// Scratchpad buffers for Q, K, V, Mask, and VKQ32 accumulator
uint8_t * spad_q = octx->src0_spad.data + octx->src0_spad.size_per_thread * ith;
uint8_t * spad_k = octx->src1_spad.data + octx->src1_spad.size_per_thread * ith;
uint8_t * spad_v = octx->src2_spad.data + octx->src2_spad.size_per_thread * ith;
uint8_t * spad_m = octx->src3_spad.data + octx->src3_spad.size_per_thread * ith;
uint8_t * spad_a = octx->dst_spad.data + octx->dst_spad.size_per_thread * ith;
const uint32_t n_head = neq2;
const uint32_t n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
const float m0 = powf(2.0f, -(max_bias ) / n_head_log2);
const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
for (uint32_t ir = ir0; ir < ir1; ++ir) {
const uint32_t iq3 = fastdiv(ir, &octx->src0_div21);
const uint32_t iq2 = fastdiv(ir - iq3*neq2*neq1, &octx->src0_div1);
const uint32_t iq1 = (ir - iq3*neq2*neq1 - iq2 * neq1);
const uint32_t ik3 = fastdiv(iq3, &octx->broadcast_rk3);
const uint32_t ik2 = fastdiv(iq2, &octx->broadcast_rk2);
const uint32_t iv3 = fastdiv(iq3, &octx->broadcast_rv3);
const uint32_t iv2 = fastdiv(iq2, &octx->broadcast_rv2);
// Fetch Q row
const uint8_t * q_row_ptr = (const uint8_t *) q->data + (iq1*nbq1 + iq2*nbq2 + iq3*nbq3);
dma_queue_push(dma, dma_make_ptr(spad_q, q_row_ptr), size_q_row_padded, nbq1, size_q_row, 1);
const uint32_t h = iq2; // head index
const float slope = (max_bias > 0.0f) ? (h < n_head_log2 ? powf(m0, h + 1) : powf(m1, 2*(h - n_head_log2) + 1)) : 1.0f;
float S = 0.0f; // sum
float M = -INFINITY; // maximum KQ value
// Clear accumulator
float * VKQ32 = (float *) spad_a;
memset(VKQ32, 0, DV * sizeof(float));
const __fp16 * mp_base = NULL;
if (mask) {
const uint32_t im2 = fastmodulo(iq2, mask->ne[2], &octx->src3_div2);
const uint32_t im3 = fastmodulo(iq3, mask->ne[3], &octx->src3_div3);
mp_base = (const __fp16 *) ((const uint8_t *) mask->data + iq1*mask->nb[1] + im2*mask->nb[2] + im3*mask->nb[3]);
}
const uint32_t n_blocks = (nek1 + FLASH_ATTN_BLOCK_SIZE - 1) / FLASH_ATTN_BLOCK_SIZE;
// Prefetch first two blocks
for (uint32_t ib = 0; ib < MIN(n_blocks, 2); ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
uint8_t * k_dst = spad_k + (ib % 2) * size_k_block;
dma_queue_push(dma, dma_make_ptr(k_dst, k_src), size_k_row_padded, nbk1, size_k_row, current_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
uint8_t * v_dst = spad_v + (ib % 2) * size_v_block;
dma_queue_push(dma, dma_make_ptr(v_dst, v_src), size_v_row_padded, nbv1, size_v_row, current_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + ic_start);
uint8_t * m_dst = spad_m + (ib % 2) * size_m_block;
// Mask is 1D contiguous for this row
dma_queue_push(dma, dma_make_ptr(m_dst, m_src), current_block_size * 2, current_block_size * 2, current_block_size * 2, 1);
}
}
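// Double buffering: two blocks stay in flight; each loop iteration below
// consumes one block and re-issues its scratch buffers for block ib+2.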
const uint8_t * q_ptr_vtcm = dma_queue_pop(dma).dst;
for (uint32_t ib = 0; ib < n_blocks; ++ib) {
const uint32_t ic_start = ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t current_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - ic_start);
// Wait for DMA
uint8_t * k_base = dma_queue_pop(dma).dst; // K
uint8_t * v_base = dma_queue_pop(dma).dst; // V
__fp16 * m_base = mask ? dma_queue_pop(dma).dst : NULL; // M
// Inner loop processing the block from VTCM
uint32_t ic = 0;
// Process in blocks of 32 (VLEN_FP32)
for (; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32) {
// 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale);
}
}
HVX_Vector scores = *(HVX_Vector *) scores_arr;
// 2. Softcap
if (logit_softcap != 0.0f) {
scores = hvx_vec_tanh_fp32(scores);
scores = Q6_Vqf32_vmpy_VsfVsf(scores, hvx_vec_splat_fp32(logit_softcap));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 3. Mask
if (mask) {
const __fp16 * mp = m_base + ic;
HVX_Vector m_vals_fp16 = *(const HVX_UVector *) mp;
HVX_Vector one_fp16 = Q6_Vh_vsplat_R(0x3c00);
HVX_VectorPair m_vals_fp32_pair = Q6_Wqf32_vmpy_VhfVhf(Q6_Vh_vshuff_Vh(m_vals_fp16), one_fp16);
HVX_Vector m_vals_fp32 = Q6_Vsf_equals_Vqf32(Q6_V_lo_W(m_vals_fp32_pair));
HVX_Vector slope_vec = hvx_vec_splat_fp32(slope);
HVX_Vector add_val = Q6_Vqf32_vmpy_VsfVsf(m_vals_fp32, slope_vec);
scores = Q6_Vqf32_vadd_VsfVsf(scores, Q6_Vsf_equals_Vqf32(add_val));
scores = Q6_Vsf_equals_Vqf32(scores);
}
// 4. Online Softmax Update
HVX_Vector v_max = hvx_vec_reduce_max_fp32(scores);
float m_block = hvx_vec_get_fp32(v_max);
float M_old = M;
float M_new = (m_block > M) ? m_block : M;
M = M_new;
float ms = expf(M_old - M_new);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
S = S * ms;
HVX_Vector M_new_vec = hvx_vec_splat_fp32(M_new);
HVX_Vector scores_shifted = Q6_Vqf32_vsub_VsfVsf(scores, M_new_vec);
HVX_Vector P = hvx_vec_exp_fp32(Q6_Vsf_equals_Vqf32(scores_shifted));
HVX_Vector p_sum_vec = hvx_vec_fp32_reduce_sum(P);
float p_sum = hvx_vec_get_fp32(p_sum_vec);
S += p_sum;
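// Online-softmax invariant: with M_new = max(M_old, max_j s_j),
//   VKQ <- VKQ * exp(M_old - M_new)
//   S   <- S * exp(M_old - M_new) + sum_j exp(s_j - M_new)
// so that VKQ / S equals softmax(scores) @ V once all blocks are folded in.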
// 5. Accumulate V
float __attribute__((aligned(VLEN))) p_arr[VLEN_FP32];
*(HVX_Vector*)p_arr = P;
for (int j = 0; j < VLEN_FP32; ++j) {
const uint32_t cur_ic = ic + j;
const uint8_t * v_ptr = v_base + cur_ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, p_arr[j]);
}
}
// Leftover
for (; ic < current_block_size; ++ic) {
float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
}
if (logit_softcap != 0.0f) {
s_val = logit_softcap * tanhf(s_val);
}
if (mask) {
const float m_val = m_base[ic];
s_val += slope * m_val;
}
const float Mold = M;
float ms = 1.0f;
float vs = 1.0f;
if (s_val > M) {
M = s_val;
ms = expf(Mold - M);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s_val - M);
}
const uint8_t * v_ptr = v_base + ic * size_v_row_padded;
hvx_mad_f32_f16_aa(VKQ32, v_ptr, DV, vs);
S = S * ms + vs;
}
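// Scalar tail follows the same recurrence: if s_val <= M the running max is
// unchanged (ms == 1) and the row enters with weight vs = exp(s_val - M);
// otherwise the accumulator is rescaled by ms first.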
// Issue DMA for next+1 block (if exists)
if (ib + 2 < n_blocks) {
const uint32_t next_ib = ib + 2;
const uint32_t next_ic_start = next_ib * FLASH_ATTN_BLOCK_SIZE;
const uint32_t next_block_size = MIN(FLASH_ATTN_BLOCK_SIZE, nek1 - next_ic_start);
// K
const uint8_t * k_src = (const uint8_t *) k->data + (next_ic_start*nbk1 + ik2*nbk2 + ik3*nbk3);
dma_queue_push(dma, dma_make_ptr(k_base, k_src), size_k_row_padded, nbk1, size_k_row, next_block_size);
// V
const uint8_t * v_src = (const uint8_t *) v->data + (next_ic_start*nbv1 + iv2*nbv2 + iv3*nbv3);
dma_queue_push(dma, dma_make_ptr(v_base, v_src), size_v_row_padded, nbv1, size_v_row, next_block_size);
// Mask
if (mask) {
const uint8_t * m_src = (const uint8_t *) (mp_base + next_ic_start);
dma_queue_push(dma, dma_make_ptr(m_base, m_src), next_block_size * 2, next_block_size * 2, next_block_size * 2, 1);
}
}
}
// sinks
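// A sink is one extra logit with no V row: it can raise the running max
// (rescaling the accumulator) or just add weight to the normalizer S,
// absorbing probability mass from the regular tokens.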
if (sinks) {
const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f;
float vs = 1.0f;
if (s > M) {
ms = expf(M - s);
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, ms);
} else {
vs = expf(s - M);
}
S = S * ms + vs;
}
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
hvx_scale_f32_aa((uint8_t *) VKQ32, (const uint8_t *) VKQ32, DV, S_inv);
// Store result
// dst indices
const int i1 = iq1;
const int i2 = iq2;
const int i3 = iq3;
// dst is permuted
uint8_t * dst_ptr = (uint8_t *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1) * nb1;
if (dst->type == HTP_TYPE_F32) {
hvx_copy_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
} else if (dst->type == HTP_TYPE_F16) {
hvx_copy_fp16_fp32_ua(dst_ptr, (uint8_t *) VKQ32, DV);
}
}
}
static void htp_flash_attn_ext_job(unsigned int n, unsigned int i, void * data) {
struct htp_ops_context * octx = data;
flash_attn_ext_f16_thread(octx, i, n);
}
int op_flash_attn_ext(struct htp_ops_context * octx) {
const struct htp_tensor * q = &octx->src0;
const struct htp_tensor * k = &octx->src1;
const struct htp_tensor * v = &octx->src2;
const struct htp_tensor * mask = (octx->src3.type != HTP_TYPE_COUNT) ? &octx->src3 : NULL;
struct htp_tensor * dst = &octx->dst;
// Check support
if ((q->type != HTP_TYPE_F16 && q->type != HTP_TYPE_F32) ||
k->type != HTP_TYPE_F16 ||
v->type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
octx->src0_div21 = init_fastdiv_values(q->ne[2] * q->ne[1]);
octx->src0_div1 = init_fastdiv_values(q->ne[1]);
octx->broadcast_rk2 = init_fastdiv_values(q->ne[2]/k->ne[2]);
octx->broadcast_rk3 = init_fastdiv_values(q->ne[3]/k->ne[3]);
octx->broadcast_rv2 = init_fastdiv_values(q->ne[2]/v->ne[2]);
octx->broadcast_rv3 = init_fastdiv_values(q->ne[3]/v->ne[3]);
if (mask) {
octx->src3_div2 = init_fastdiv_values(mask->ne[2]);
octx->src3_div3 = init_fastdiv_values(mask->ne[3]);
}
size_t size_q_row_padded = htp_round_up(q->ne[0] * (q->type == HTP_TYPE_F32 ? 4 : 2), 128);
size_t size_k_row_padded = htp_round_up(k->ne[0] * sizeof(__fp16), 128);
size_t size_v_row_padded = htp_round_up(v->ne[0] * sizeof(__fp16), 128);
size_t size_q_block = size_q_row_padded * 1; // single row for now
size_t size_k_block = size_k_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_v_block = size_v_row_padded * FLASH_ATTN_BLOCK_SIZE;
size_t size_m_block = htp_round_up(FLASH_ATTN_BLOCK_SIZE * sizeof(__fp16), 128);
size_t size_vkq_acc = htp_round_up(v->ne[0] * sizeof(float), 128); // VKQ32
octx->src0_spad.size_per_thread = size_q_block * 1;
octx->src1_spad.size_per_thread = size_k_block * 2;
octx->src2_spad.size_per_thread = size_v_block * 2;
octx->src3_spad.size_per_thread = mask ? size_m_block * 2 : 0;
octx->dst_spad.size_per_thread = size_vkq_acc;
octx->src0_spad.size = octx->src0_spad.size_per_thread * octx->n_threads;
octx->src1_spad.size = octx->src1_spad.size_per_thread * octx->n_threads;
octx->src2_spad.size = octx->src2_spad.size_per_thread * octx->n_threads;
octx->src3_spad.size = octx->src3_spad.size_per_thread * octx->n_threads;
octx->dst_spad.size = octx->dst_spad.size_per_thread * octx->n_threads;
size_t total_spad = octx->src0_spad.size + octx->src1_spad.size + octx->src2_spad.size + octx->src3_spad.size + octx->dst_spad.size;
if (octx->ctx->vtcm_size < total_spad) {
return HTP_STATUS_VTCM_TOO_SMALL;
}
octx->src0_spad.data = octx->ctx->vtcm_base;
octx->src1_spad.data = octx->src0_spad.data + octx->src0_spad.size;
octx->src2_spad.data = octx->src1_spad.data + octx->src1_spad.size;
octx->src3_spad.data = octx->src2_spad.data + octx->src2_spad.size;
octx->dst_spad.data = octx->src3_spad.data + octx->src3_spad.size;
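// VTCM is carved sequentially into [Q | K | V | mask | acc] regions; each
// worker thread then takes a size_per_thread slice within its region
// (see flash_attn_ext_f16_thread above).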
if (!(octx->flags & HTP_OPFLAGS_SKIP_COMPUTE)) {
worker_pool_run_func(octx->ctx->worker_pool, htp_flash_attn_ext_job, octx, octx->n_threads);
}
return HTP_STATUS_OK;
}
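// ---------------------------------------------------------------------------
// Reference sketch (an assumption, not part of this backend): the online
// softmax recurrence implemented above, restated on plain floats. A host-side
// test could fold scores into (M, S, acc) one at a time with this helper and
// compare acc / S against the HVX output. Names are illustrative.
static inline void ref_online_softmax_step(float * M, float * S, float * acc,
                                           const float * v_row, uint32_t DV, float s) {
    const float M_new = (s > *M) ? s : *M;  // updated running max
    const float ms    = expf(*M - M_new);   // rescale factor for the old state
    const float vs    = expf(s - M_new);    // weight of the incoming score
    for (uint32_t d = 0; d < DV; ++d) {
        acc[d] = acc[d] * ms + vs * v_row[d];
    }
    *S = *S * ms + vs;  // running normalizer
    *M = M_new;
}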

View File

@ -1,112 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define get_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t nr = ne10 * ne11 * ne12;
static int get_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
get_rows_preamble;
// parallelize by src1 elements (which correspond to dst rows)
const uint32_t dr = octx->src1_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastdiv(i, &octx->get_rows_div_ne10_ne11);
const uint32_t rem = i - i12 * ne11 * ne10;
const uint32_t i11 = fastdiv(rem, &octx->get_rows_div_ne10);
const uint32_t i10 = rem - i11 * ne10;
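// Decompose the flat work index: i == (i12*ne11 + i11)*ne10 + i10.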
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i01 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i01 >= ne01) {
// invalid index, skip for now to avoid crash
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i01*nb01 + i11*nb02 + i12*nb03;
const uintptr_t dst_ptr = octx->dst.data + i10*nb1 + i11*nb2 + i12*nb3;
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
return HTP_STATUS_OK;
}
static void get_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
get_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_get_rows(struct htp_ops_context * octx) {
get_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->get_rows_div_ne10 = init_fastdiv_values(octx->src1.ne[0]);
octx->get_rows_div_ne10_ne11 = init_fastdiv_values(octx->src1.ne[0] * octx->src1.ne[1]);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src1_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
worker_pool_run_func(octx->ctx->worker_pool, get_rows_work_f32_f32, octx, n_jobs);
return HTP_STATUS_OK;
}
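// Reference sketch (an assumption, not part of this backend): the semantics
// op_get_rows implements, on plain f32 rows with i32 indices. Each src1 entry
// selects one src0 row to copy into the corresponding dst row.
static inline void ref_get_rows_f32(float * dst, const float * src0, const int32_t * rows,
                                    uint32_t n_rows, uint32_t row_elems, uint32_t src0_rows) {
    for (uint32_t i = 0; i < n_rows; ++i) {
        const uint32_t r = (uint32_t) rows[i];
        if (r >= src0_rows) {
            continue;  // same policy as above: skip invalid indices
        }
        memcpy(dst + (size_t) i * row_elems, src0 + (size_t) r * row_elems,
               row_elems * sizeof(float));
    }
}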

View File

@ -11,6 +11,11 @@
#define HTP_MAX_NTHREADS 10 #define HTP_MAX_NTHREADS 10
// FIXME: move these into matmul-ops
#define HTP_SPAD_SRC0_NROWS 16
#define HTP_SPAD_SRC1_NROWS 16
#define HTP_SPAD_DST_NROWS 2
// Main context for htp DSP backend // Main context for htp DSP backend
struct htp_context { struct htp_context {
dspqueue_t queue; dspqueue_t queue;

View File

@ -36,8 +36,6 @@ enum htp_data_type {
HTP_TYPE_F16 = 1, HTP_TYPE_F16 = 1,
HTP_TYPE_Q4_0 = 2, HTP_TYPE_Q4_0 = 2,
HTP_TYPE_Q8_0 = 8, HTP_TYPE_Q8_0 = 8,
HTP_TYPE_I32 = 26,
HTP_TYPE_I64 = 27,
HTP_TYPE_MXFP4 = 39, HTP_TYPE_MXFP4 = 39,
HTP_TYPE_COUNT HTP_TYPE_COUNT
}; };
@ -59,10 +57,6 @@ enum htp_op {
HTP_OP_SOFTMAX = 11, HTP_OP_SOFTMAX = 11,
HTP_OP_ADD_ID = 12, HTP_OP_ADD_ID = 12,
HTP_OP_ROPE = 13, HTP_OP_ROPE = 13,
HTP_OP_FLASH_ATTN_EXT = 14,
HTP_OP_SET_ROWS = 15,
HTP_OP_SCALE = 16,
HTP_OP_GET_ROWS = 17,
INVALID INVALID
}; };
@ -143,8 +137,6 @@ struct htp_general_req {
struct htp_tensor src0; // Input0 tensor struct htp_tensor src0; // Input0 tensor
struct htp_tensor src1; // Input1 tensor struct htp_tensor src1; // Input1 tensor
struct htp_tensor src2; // Input2 tensor struct htp_tensor src2; // Input2 tensor
struct htp_tensor src3; // Input3 tensor
struct htp_tensor src4; // Input4 tensor
struct htp_tensor dst; // Output tensor struct htp_tensor dst; // Output tensor
// should be multiple of 64 bytes (cacheline) // should be multiple of 64 bytes (cacheline)
@ -160,6 +152,6 @@ struct htp_general_rsp {
}; };
#define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req) #define HTP_MAX_MESSAGE_SIZE sizeof(struct htp_general_req)
#define HTP_MAX_PACKET_BUFFERS 8 #define HTP_MAX_PACKET_BUFFERS 4
#endif /* HTP_MSG_H */ #endif /* HTP_MSG_H */

View File

@ -13,7 +13,6 @@
struct htp_spad { struct htp_spad {
uint8_t * data; uint8_t * data;
size_t stride;
size_t size; size_t size;
size_t size_per_thread; size_t size_per_thread;
}; };
@ -27,14 +26,11 @@ struct htp_ops_context {
struct htp_tensor src0; struct htp_tensor src0;
struct htp_tensor src1; struct htp_tensor src1;
struct htp_tensor src2; struct htp_tensor src2;
struct htp_tensor src3;
struct htp_tensor src4;
struct htp_tensor dst; struct htp_tensor dst;
struct htp_spad src0_spad; struct htp_spad src0_spad;
struct htp_spad src1_spad; struct htp_spad src1_spad;
struct htp_spad src2_spad; struct htp_spad src2_spad;
struct htp_spad src3_spad;
struct htp_spad dst_spad; struct htp_spad dst_spad;
worker_pool_context_t * wpool; // worker pool worker_pool_context_t * wpool; // worker pool
@ -53,27 +49,6 @@ struct htp_ops_context {
struct fastdiv_values src1_div3; // fastdiv values for ne3 struct fastdiv_values src1_div3; // fastdiv values for ne3
struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1 struct fastdiv_values src1_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values src3_div1; // fastdiv values for ne1
struct fastdiv_values src3_div2; // fastdiv values for ne2
struct fastdiv_values src3_div3; // fastdiv values for ne3
struct fastdiv_values src3_div21; // fastdiv values for ne2 * ne1
struct fastdiv_values broadcast_rk2;
struct fastdiv_values broadcast_rk3;
struct fastdiv_values broadcast_rv2;
struct fastdiv_values broadcast_rv3;
struct fastdiv_values mm_div_ne12_ne1; // fastdiv values for ne12 * ne1
struct fastdiv_values mm_div_ne1; // fastdiv values for ne1
struct fastdiv_values mm_div_r2; // fastdiv values for ne12 / ne02
struct fastdiv_values mm_div_r3; // fastdiv values for ne13 / ne03
struct fastdiv_values set_rows_div_ne12; // fastdiv values for ne12
struct fastdiv_values set_rows_div_ne11; // fastdiv values for ne11
struct fastdiv_values get_rows_div_ne10; // fastdiv values for ne10
struct fastdiv_values get_rows_div_ne10_ne11; // fastdiv values for ne10 * ne11
uint32_t flags; uint32_t flags;
}; };
@ -85,8 +60,5 @@ int op_activations(struct htp_ops_context * octx);
int op_softmax(struct htp_ops_context * octx); int op_softmax(struct htp_ops_context * octx);
int op_add_id(struct htp_ops_context * octx); int op_add_id(struct htp_ops_context * octx);
int op_rope(struct htp_ops_context * octx); int op_rope(struct htp_ops_context * octx);
int op_flash_attn_ext(struct htp_ops_context * octx);
int op_set_rows(struct htp_ops_context * octx);
int op_get_rows(struct htp_ops_context * octx);
#endif /* HTP_OPS_H */ #endif /* HTP_OPS_H */

View File

@ -848,6 +848,55 @@ float hvx_self_sum_f32(const uint8_t * restrict src, const int num_elems) {
return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v)); return hvx_vec_get_fp32(Q6_Vsf_equals_Vqf32(v));
} }
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale) {
int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over;
int unaligned_addr = 0;
int unaligned_loop = 0;
if ((0 == htp_is_aligned((void *) src, VLEN)) || (0 == htp_is_aligned((void *) dst, VLEN))) {
FARF(HIGH, "hvx_scale_f32: unaligned address in hvx op, possibly slower execution\n");
unaligned_addr = 1;
}
if ((1 == unaligned_addr) && (num_elems_whole != 0)) {
unaligned_loop = 1;
FARF(HIGH, "hvx_scale_f32: unaligned loop in hvx op, possibly slower execution\n");
}
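// The aligned path below uses native HVX_Vector loads/stores; the unaligned
// path falls back to HVX_UVector accesses, which may be slower.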
HVX_Vector scale_vec = hvx_vec_splat_fp32(scale);
if (0 == unaligned_loop) {
HVX_Vector * vec_in1 = (HVX_Vector *) src;
HVX_Vector * vec_out = (HVX_Vector *) dst;
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(*vec_in1++, scale_vec);
*vec_out++ = Q6_Vsf_equals_Vqf32(v);
}
} else {
#pragma unroll(4)
for (int i = 0; i < num_elems_whole; i += VLEN_FP32) {
HVX_Vector in = *(HVX_UVector *) (src + i * SIZEOF_FP32);
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
*(HVX_UVector *) (dst + i * SIZEOF_FP32) = Q6_Vsf_equals_Vqf32(out);
}
}
if (left_over > 0) {
const float * srcf = (const float *) src + num_elems_whole;
float * dstf = (float *) dst + num_elems_whole;
HVX_Vector in = *(HVX_UVector *) srcf;
HVX_Vector out = Q6_Vqf32_vmpy_VsfVsf(in, scale_vec);
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, Q6_Vsf_equals_Vqf32(out));
}
}
float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) { float hvx_self_max_f32(const uint8_t * restrict src, const int num_elems) {
int left_over = num_elems & (VLEN_FP32 - 1); int left_over = num_elems & (VLEN_FP32 - 1);
int num_elems_whole = num_elems - left_over; int num_elems_whole = num_elems - left_over;
@ -1016,5 +1065,3 @@ void hvx_clamp_scalar_f32(const uint8_t * restrict src,
hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec); hvx_vec_store_u((void *) dstf, left_over * SIZEOF_FP32, in_vec);
} }
} }

View File

@ -41,24 +41,15 @@ static inline HVX_Vector Q6_Vsf_equals_Vw(HVX_Vector const in)
} }
#endif #endif
static inline HVX_Vector hvx_vec_splat_fp32(float v) { static inline HVX_Vector hvx_vec_splat_fp32(float i) {
union { union {
float f; float f;
uint32_t i; int32_t i;
} fp32 = { .f = v }; } fp32 = { .f = i };
return Q6_V_vsplat_R(fp32.i); return Q6_V_vsplat_R(fp32.i);
} }
static inline HVX_Vector hvx_vec_splat_fp16(float v) {
union {
__fp16 f;
uint16_t i;
} fp16 = { .f = v };
return Q6_Vh_vsplat_R(fp16.i);
}
static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) { static inline void hvx_vec_store_u(void * addr, uint32_t n, HVX_Vector v) {
// Rotate as needed. // Rotate as needed.
v = Q6_V_vlalign_VVR(v, v, (size_t) addr); v = Q6_V_vlalign_VVR(v, v, (size_t) addr);
@ -251,120 +242,6 @@ static inline void hvx_copy_fp32_au(uint8_t * restrict dst, const uint8_t * rest
} }
} }
// copy n fp32 elements : source is unaligned, destination unaligned
static inline void hvx_copy_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst;
HVX_UVector * restrict vsrc = (HVX_UVector *) src;
uint32_t nvec = n / 32;
uint32_t nloe = n % 32;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
HVX_Vector v = vsrc[i];
vdst[i] = v;
}
if (nloe) {
HVX_Vector v = vsrc[i];
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(float), v);
}
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_uu(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
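// f32 -> f16 idiom: subtracting zero in qf32 is a cheap sf -> qf32 cast;
// pairs of qf32 vectors are then narrowed with Q6_Vhf_equals_Wqf32 and
// vdeal'ed back into element order.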
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// copy/convert n fp32 elements into n fp16 elements : source is aligned, destination is unaligned
static inline void hvx_copy_fp16_fp32_ua(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_UVector * restrict vdst = (HVX_UVector *) dst; // fp16
HVX_Vector * restrict vsrc = (HVX_Vector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// copy/convert n fp32 elements into n fp16 elements : source is unaligned, destination is aligned
static inline void hvx_copy_fp16_fp32_au(uint8_t * restrict dst, const uint8_t * restrict src, uint32_t n) {
HVX_Vector * restrict vdst = (HVX_Vector *) dst; // fp16
HVX_UVector * restrict vsrc = (HVX_UVector *) src; // fp32
const HVX_Vector zero = Q6_V_vsplat_R(0);
uint32_t nvec = n / 64;
uint32_t nloe = n % 64;
uint32_t i = 0;
#pragma unroll(4)
for (; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
vdst[i] = Q6_Vh_vdeal_Vh(s_hf);
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector s0_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+0], zero); // 32 elements
HVX_Vector s1_qf = Q6_Vqf32_vsub_VsfVsf(vsrc[i*2+1], zero); // 32 elements
HVX_Vector s_hf = Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(s1_qf, s0_qf));
hvx_vec_store_u((void *) &vdst[i], nloe * sizeof(__fp16), Q6_Vh_vdeal_Vh(s_hf));
}
}
// bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned // bcast 1 fp32 element from source to n fp32 elements in destination : destination is aligned
static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) { static inline void hvx_bcast_fp32_a(uint8_t * restrict dst, float elem, uint32_t n) {
HVX_Vector * restrict vdst = (HVX_Vector *) dst; HVX_Vector * restrict vdst = (HVX_Vector *) dst;
@ -396,6 +273,8 @@ static __attribute__((always_inline)) int32_t is_in_one_chunk(void * addr, uint3
return right_off <= chunk_size; return right_off <= chunk_size;
} }
static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) { static void hvx_vec_dump_fp16_n(char * pref, HVX_Vector v, uint32_t n) {
HVX_VectorAlias u = { .v = v }; HVX_VectorAlias u = { .v = v };
@ -652,13 +531,13 @@ static inline HVX_Vector hvx_vec_abs_fp32(HVX_Vector v) {
} }
static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) { static inline HVX_Vector hvx_vec_neg_fp32(HVX_Vector v) {
#if __HVX_ARCH__ > 75 #if __HTP_ARCH__ > 75
return Q6_Vsf_vfneg_Vsf(v); return Q6_Vsf_vfneg_Vsf(v);
#else #else
// neg by setting the fp32 sign bit // neg by setting the fp32 sign bit
HVX_Vector mask = Q6_V_vsplat_R(0x80000000); HVX_Vector mask = Q6_V_vsplat_R(0x80000000);
return Q6_V_vxor_VV(v, mask); return Q6_V_vxor_VV(v, mask);
#endif // __HVX_ARCH__ > 75 #endif // __HTP_ARCH__ > 75
} }
// ==================================================== // ====================================================
@ -1097,24 +976,6 @@ static inline HVX_Vector hvx_vec_fast_sigmoid_fp32_guard(HVX_Vector v,
return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero()); return Q6_V_vmux_QVV(pred_min, out, Q6_V_vzero());
} }
static inline HVX_Vector hvx_vec_tanh_fp32(HVX_Vector x) {
// tanh(x) = 2 * sigmoid(2x) - 1
HVX_Vector two = hvx_vec_splat_fp32(2.0f);
HVX_Vector one = hvx_vec_splat_fp32(1.0f);
HVX_Vector x2 = Q6_Vqf32_vmpy_VsfVsf(x, two);
static const float kMinExp = -87.f; // sigmoid saturates to 0 below this
static const float kMaxExp = 87.f; // sigmoid saturates to 1 above this
HVX_Vector max_exp = hvx_vec_splat_fp32(kMaxExp);
HVX_Vector min_exp = hvx_vec_splat_fp32(kMinExp);
HVX_Vector sig2x = hvx_vec_fast_sigmoid_fp32_guard(Q6_Vsf_equals_Vqf32(x2), one, max_exp, min_exp);
HVX_Vector res = Q6_Vqf32_vmpy_VsfVsf(sig2x, two);
res = Q6_Vqf32_vsub_Vqf32Vsf(res, one);
return Q6_Vsf_equals_Vqf32(res);
}
static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) { static inline void hvx_fast_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems) {
int step_of_1 = num_elems >> 5; int step_of_1 = num_elems >> 5;
int remaining = num_elems - step_of_1 * VLEN_FP32; int remaining = num_elems - step_of_1 * VLEN_FP32;
@ -1195,115 +1056,6 @@ static inline void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restr
} }
} }
static inline void hvx_scale_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector * vsrc = (HVX_Vector *) src;
HVX_Vector * vdst = (HVX_Vector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_UVector * vsrc = (HVX_UVector *) src;
HVX_UVector * vdst = (HVX_UVector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale) {
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
hvx_scale_f32_aa(dst, src, n, scale);
} else {
hvx_scale_f32_uu(dst, src, n, scale);
}
}
static inline void hvx_scale_offset_f32_aa(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector vo = hvx_vec_splat_fp32(offset);
HVX_Vector * vsrc = (HVX_Vector *) src;
HVX_Vector * vdst = (HVX_Vector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_offset_f32_uu(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
int nvec = n / VLEN_FP32;
int nloe = n % VLEN_FP32;
HVX_Vector vs = hvx_vec_splat_fp32(scale);
HVX_Vector vo = hvx_vec_splat_fp32(offset);
HVX_UVector * vsrc = (HVX_UVector *) src;
HVX_UVector * vdst = (HVX_UVector *) dst;
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; ++i) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
vdst[i] = Q6_Vsf_equals_Vqf32(v);
}
if (nloe) {
HVX_Vector v = Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vmpy_VsfVsf(vsrc[i], vs), vo);
hvx_vec_store_u((void *) &vdst[i], nloe * 4, Q6_Vsf_equals_Vqf32(v));
}
}
static inline void hvx_scale_offset_f32(uint8_t * restrict dst, const uint8_t * restrict src, const int n, const float scale, const float offset) {
if (htp_is_aligned((void *) src, VLEN) && htp_is_aligned((void *) dst, VLEN)) {
hvx_scale_offset_f32_aa(dst, src, n, scale, offset);
} else {
hvx_scale_offset_f32_uu(dst, src, n, scale, offset);
}
}
float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems); float hvx_sum_of_squares_f32(const uint8_t * restrict src, const int num_elems);
void hvx_mul_f32(const uint8_t * restrict src0, void hvx_mul_f32(const uint8_t * restrict src0,
@ -1338,6 +1090,7 @@ void hvx_sub_f32_opt(const uint8_t * restrict src0,
uint8_t * restrict dst, uint8_t * restrict dst,
const int num_elems); const int num_elems);
void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems); void hvx_sub_scalar_f32(const uint8_t * restrict src, const float val, uint8_t * restrict dst, const int num_elems);
void hvx_scale_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, const float scale);
void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_inverse_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems); void hvx_sigmoid_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems);
void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate); void hvx_exp_f32(const uint8_t * restrict src, uint8_t * restrict dst, const int num_elems, bool negate);

View File

@ -443,45 +443,6 @@ static void proc_matmul_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
} }
static void proc_get_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// Since we wrote to the output buffer, we also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_get_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_matmul_id_req(struct htp_context * ctx, static void proc_matmul_id_req(struct htp_context * ctx,
struct htp_general_req * req, struct htp_general_req * req,
struct dspqueue_buffer * bufs, struct dspqueue_buffer * bufs,
@ -707,7 +668,7 @@ static void proc_rope_req(struct htp_context * ctx,
uint32_t n_bufs) { uint32_t n_bufs) {
struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS]; struct dspqueue_buffer rsp_bufs[HTP_MAX_PACKET_BUFFERS];
int write_idx = n_bufs - 1; int write_idx = (n_bufs == 4) ? 3 : 2;
// Since we wrote to the output buffer, we also need to flush it // Since we wrote to the output buffer, we also need to flush it
rsp_bufs[0].fd = bufs[write_idx].fd; rsp_bufs[0].fd = bufs[write_idx].fd;
@ -755,102 +716,6 @@ static void proc_rope_req(struct htp_context * ctx,
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof); send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
} }
static void proc_set_rows_req(struct htp_context * ctx, struct htp_general_req * req, struct dspqueue_buffer * bufs) {
struct dspqueue_buffer rsp_bufs[1];
// Since we wrote to the output buffer, we also need to flush it
rsp_bufs[0].fd = bufs[2].fd;
rsp_bufs[0].ptr = bufs[2].ptr;
rsp_bufs[0].offset = bufs[2].offset;
rsp_bufs[0].size = bufs[2].size;
rsp_bufs[0].flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
// Setup Op context
struct htp_ops_context octx = { 0 };
octx.ctx = ctx;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.dst.data = (uint32_t) bufs[2].ptr;
octx.n_threads = ctx->n_threads;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_set_rows(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
send_htp_rsp(ctx, req->op, rsp_status, rsp_bufs, 1, &prof);
}
static void proc_flash_attn_ext_req(struct htp_context * ctx,
struct htp_general_req * req,
struct dspqueue_buffer * bufs,
uint32_t n_bufs) {
// Setup Op context
struct htp_ops_context octx;
memset(&octx, 0, sizeof(octx));
octx.ctx = ctx;
octx.n_threads = ctx->n_threads;
octx.src0 = req->src0;
octx.src1 = req->src1;
octx.src2 = req->src2;
octx.src3 = req->src3;
octx.src4 = req->src4;
octx.dst = req->dst;
octx.flags = req->flags;
octx.op = req->op;
memcpy(octx.op_params, req->op_params, sizeof(octx.op_params));
// Update data pointers
octx.src0.data = (uint32_t) bufs[0].ptr;
octx.src1.data = (uint32_t) bufs[1].ptr;
octx.src2.data = (uint32_t) bufs[2].ptr;
int last_buf = 3;
if (octx.src3.ne[0]) {
octx.src3.data = (uint32_t) bufs[last_buf++].ptr; // mask is valid
}
if (octx.src4.ne[0]) {
octx.src4.data = (uint32_t) bufs[last_buf++].ptr; // sinks is valid
}
octx.dst.data = (uint32_t) bufs[last_buf].ptr;
struct profile_data prof;
profile_start(&prof);
uint32_t rsp_status = HTP_STATUS_INTERNAL_ERR;
if (vtcm_acquire(ctx) == AEE_SUCCESS) {
rsp_status = op_flash_attn_ext(&octx);
vtcm_release(ctx);
}
profile_stop(&prof);
struct dspqueue_buffer rsp_buf = bufs[last_buf];
rsp_buf.flags = (DSPQUEUE_BUFFER_FLAG_FLUSH_SENDER | // Flush HTP
DSPQUEUE_BUFFER_FLAG_INVALIDATE_RECIPIENT); // Invalidate CPU
send_htp_rsp(ctx, req->op, rsp_status, &rsp_buf, 1, &prof);
}
static void htp_packet_callback(dspqueue_t queue, int error, void * context) { static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
struct htp_context * ctx = (struct htp_context *) context; struct htp_context * ctx = (struct htp_context *) context;
@ -925,7 +790,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
break; break;
case HTP_OP_RMS_NORM: case HTP_OP_RMS_NORM:
case HTP_OP_SCALE:
if (n_bufs != 2) { if (n_bufs != 2) {
FARF(ERROR, "Bad unary-req buffer list"); FARF(ERROR, "Bad unary-req buffer list");
continue; continue;
@ -969,30 +833,6 @@ static void htp_packet_callback(dspqueue_t queue, int error, void * context) {
proc_rope_req(ctx, &req, bufs, n_bufs); proc_rope_req(ctx, &req, bufs, n_bufs);
break; break;
case HTP_OP_FLASH_ATTN_EXT:
if (!(n_bufs >= 4 && n_bufs <= 6)) {
FARF(ERROR, "Bad flash-attn-ext-req buffer list");
continue;
}
proc_flash_attn_ext_req(ctx, &req, bufs, n_bufs);
break;
case HTP_OP_SET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad set-rows-req buffer list");
continue;
}
proc_set_rows_req(ctx, &req, bufs);
break;
case HTP_OP_GET_ROWS:
if (n_bufs != 3) {
FARF(ERROR, "Bad get-rows-req buffer list");
continue;
}
proc_get_rows_req(ctx, &req, bufs);
break;
default: default:
FARF(ERROR, "Unknown Op %u", req.op); FARF(ERROR, "Unknown Op %u", req.op);
break; break;

File diff suppressed because it is too large

View File

@ -1,168 +0,0 @@
#pragma clang diagnostic ignored "-Wunused-variable"
#pragma clang diagnostic ignored "-Wunused-function"
#pragma clang diagnostic ignored "-Wunused-but-set-variable"
#ifdef HTP_DEBUG
# define FARF_HIGH 1
#endif
#include <HAP_farf.h>
#include <HAP_mem.h>
#include <HAP_perf.h>
#include <hexagon_protos.h>
#include <hexagon_types.h>
#include <math.h>
#include <string.h>
#define GGML_COMMON_DECL_C
#include "ggml-common.h"
#include "htp-ctx.h"
#include "htp-msg.h"
#include "htp-ops.h"
#include "hvx-utils.h"
#include "ops-utils.h"
#define set_rows_preamble \
const uint32_t ne00 = octx->src0.ne[0]; \
const uint32_t ne01 = octx->src0.ne[1]; \
const uint32_t ne02 = octx->src0.ne[2]; \
const uint32_t ne03 = octx->src0.ne[3]; \
\
const uint32_t ne10 = octx->src1.ne[0]; \
const uint32_t ne11 = octx->src1.ne[1]; \
const uint32_t ne12 = octx->src1.ne[2]; \
\
const uint32_t nb01 = octx->src0.nb[1]; \
const uint32_t nb02 = octx->src0.nb[2]; \
const uint32_t nb03 = octx->src0.nb[3]; \
\
const uint32_t nb10 = octx->src1.nb[0]; \
const uint32_t nb11 = octx->src1.nb[1]; \
const uint32_t nb12 = octx->src1.nb[2]; \
\
const uint32_t nb1 = octx->dst.nb[1]; \
const uint32_t nb2 = octx->dst.nb[2]; \
const uint32_t nb3 = octx->dst.nb[3]; \
\
const uint32_t ne1 = octx->dst.ne[1]; \
\
const uint32_t nr = ne01;
static int set_rows_thread_f32_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uintptr_t src0_ptr = octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
const uintptr_t dst_ptr = octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
// copy row
hvx_copy_fp32_uu((uint8_t *)dst_ptr, (const uint8_t *)src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static int set_rows_thread_f16_f32(struct htp_ops_context * octx, const int nth, const int ith) {
set_rows_preamble;
// parallelize by rows of src0
const uint32_t dr = octx->src0_nrows_per_thread;
const uint32_t ir0 = dr * ith;
const uint32_t ir1 = (ir0 + dr < nr) ? (ir0 + dr) : nr;
const bool is_i32 = (octx->src1.type == HTP_TYPE_I32);
for (uint32_t i03 = 0; i03 < ne03; ++i03) {
for (uint32_t i02 = 0; i02 < ne02; ++i02) {
for (uint32_t i = ir0; i < ir1; ++i) {
const uint32_t i12 = fastmodulo(i03, ne12, &octx->set_rows_div_ne12);
const uint32_t i11 = fastmodulo(i02, ne11, &octx->set_rows_div_ne11);
const uint32_t i10 = i;
const uintptr_t src1_addr = octx->src1.data + i10*nb10 + i11*nb11 + i12*nb12;
uint32_t i1 = is_i32 ? *(int32_t *)src1_addr : *(int64_t *)src1_addr;
if (i1 >= ne1) {
// ignore invalid indices
continue;
}
const uint8_t* src0_ptr = (const uint8_t *) octx->src0.data + i*nb01 + i02*nb02 + i03*nb03;
uint8_t* dst_ptr = (uint8_t *) octx->dst.data + i1*nb1 + i02*nb2 + i03*nb3;
hvx_copy_fp16_fp32_uu(dst_ptr, src0_ptr, ne00);
}
}
}
return HTP_STATUS_OK;
}
static void set_rows_work_f16_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f16_f32((struct htp_ops_context *) data, n, i);
}
static void set_rows_work_f32_f32(unsigned int n, unsigned int i, void *data) {
set_rows_thread_f32_f32((struct htp_ops_context *) data, n, i);
}
int op_set_rows(struct htp_ops_context * octx) {
set_rows_preamble;
if (octx->src0.type != HTP_TYPE_F32) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->dst.type != HTP_TYPE_F32 && octx->dst.type != HTP_TYPE_F16) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->src1.type != HTP_TYPE_I32 && octx->src1.type != HTP_TYPE_I64) {
return HTP_STATUS_NO_SUPPORT;
}
if (octx->flags & HTP_OPFLAGS_SKIP_COMPUTE) {
return HTP_STATUS_OK;
}
octx->set_rows_div_ne12 = init_fastdiv_values(ne12);
octx->set_rows_div_ne11 = init_fastdiv_values(ne11);
const uint32_t n_jobs = MIN(nr, octx->n_threads);
octx->src0_nrows_per_thread = (nr + n_jobs - 1) / n_jobs;
switch(octx->dst.type) {
case HTP_TYPE_F32:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f32_f32, octx, n_jobs);
break;
case HTP_TYPE_F16:
worker_pool_run_func(octx->ctx->worker_pool, set_rows_work_f16_f32, octx, n_jobs);
break;
default:
return HTP_STATUS_NO_SUPPORT;
}
return HTP_STATUS_OK;
}

View File

@ -238,7 +238,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale, hvx_fast_softmax_prep_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale,
(const uint8_t *) mp_f32, slope); (const uint8_t *) mp_f32, slope);
} else { } else {
hvx_scale_f32((uint8_t *) wp0, (const uint8_t *) sp, ne00, softmax_ctx->scale); hvx_scale_f32((const uint8_t *) sp, (uint8_t *) wp0, ne00, softmax_ctx->scale);
if (mp_f32) { if (mp_f32) {
if (softmax_ctx->use_f16) { if (softmax_ctx->use_f16) {
for (int i = 0; i < ne00; ++i) { for (int i = 0; i < ne00; ++i) {
@ -258,7 +258,7 @@ static void softmax_htp_f32(int nth, int ith, struct softmax_th_ctx * softmax_ct
float max = hvx_self_max_f32((const uint8_t *) wp0, ne00); float max = hvx_self_max_f32((const uint8_t *) wp0, ne00);
float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max); float sum = hvx_softmax_f32((const uint8_t *) wp0, (uint8_t *) wp2, (uint8_t *) wp1, ne00, max);
sum = sum > 0.0 ? (1.0 / sum) : 1; sum = sum > 0.0 ? (1.0 / sum) : 1;
hvx_scale_f32((uint8_t *) dp, (const uint8_t *) wp2, ne00, sum); hvx_scale_f32((const uint8_t *) wp2, (uint8_t *) dp, ne00, sum);
} }
} }
} }

View File

@ -83,31 +83,6 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
} }
} }
static void scale_htp_f32(const float * restrict src,
float * restrict dst,
uint8_t * restrict spad,
const uint32_t num_rows,
const uint32_t row_elems,
const size_t row_size,
int32_t * op_params,
int opt_path) {
float scale = 0.f;
float bias = 0.f;
memcpy(&scale, &op_params[0], sizeof(float));
memcpy(&bias, &op_params[1], sizeof(float));
for (uint32_t ir = 0; ir < num_rows; ir++) {
const float * restrict src_local = src + (ir * row_elems);
float * restrict dst_local = dst + (ir * row_elems);
if (ir + 1 < num_rows) {
htp_l2fetch(src_local + row_elems, 1, row_size, row_size);
}
hvx_scale_offset_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale, bias);
}
}
static void rms_norm_htp_f32(const float * restrict src, static void rms_norm_htp_f32(const float * restrict src,
float * restrict dst, float * restrict dst,
uint8_t * restrict spad, uint8_t * restrict spad,
@ -135,7 +110,7 @@ static void rms_norm_htp_f32(const float * restrict src,
const float mean = sum / row_elems; const float mean = sum / row_elems;
const float scale = 1.0f / sqrtf(mean + epsilon); const float scale = 1.0f / sqrtf(mean + epsilon);
hvx_scale_f32((uint8_t *) dst_local, (const uint8_t *) src_local, row_elems, scale); hvx_scale_f32((const uint8_t *) src_local, (uint8_t *) dst_local, row_elems, scale);
} }
} }
} }
@ -187,9 +162,6 @@ static void unary_job_f32_per_thread(const struct htp_tensor * src,
case HTP_OP_RMS_NORM: case HTP_OP_RMS_NORM:
rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path); rms_norm_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break; break;
case HTP_OP_SCALE:
scale_htp_f32(src_th, dst_th, spad_th, src0_end_row - src0_start_row, ne0, nb1, op_params, opt_path);
break;
default: default:
break; break;
@ -223,10 +195,6 @@ static int execute_op_unary_f32(struct htp_ops_context * octx) {
unary_op_func = unary_job_dispatcher_f32; unary_op_func = unary_job_dispatcher_f32;
op_type = "rmsnorm-f32"; op_type = "rmsnorm-f32";
break; break;
case HTP_OP_SCALE:
unary_op_func = unary_job_dispatcher_f32;
op_type = "scale-f32";
break;
default: default:
FARF(ERROR, "Unsupported unary Op %u\n", octx->op); FARF(ERROR, "Unsupported unary Op %u\n", octx->op);

View File

@ -1684,60 +1684,3 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
return res; return res;
} }
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
GGML_ASSERT(op->type == GGML_TYPE_I64);
char base[256];
char name[256];
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
snprintf(name, 256, "%s", base);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
}
return res;
}
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
assert(op->op == GGML_OP_COUNT_EQUAL);
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
GGML_ASSERT(op->type == GGML_TYPE_I64);
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
char base[256];
char name[256];
int nsg = 1;
while (32*nsg < ne00 && nsg < 32) {
nsg *= 2;
}
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
snprintf(name, 256, "%s_nsg=%d", base, nsg);
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
if (!res.pipeline) {
ggml_metal_cv_t cv = ggml_metal_cv_init();
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
ggml_metal_cv_free(cv);
}
res.smem = 32 * sizeof(int32_t);
res.nsg = nsg;
return res;
}

View File

@ -147,8 +147,6 @@ struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_arange
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal (ggml_metal_library_t lib, const struct ggml_tensor * op);
struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad( struct ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_flash_attn_ext_pad(
ggml_metal_library_t lib, ggml_metal_library_t lib,

View File

@ -1023,11 +1023,6 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]); return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_L2_NORM: case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0])); return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_COUNT_EQUAL:
return has_simdgroup_reduction &&
op->src[0]->type == GGML_TYPE_I32 &&
op->src[1]->type == GGML_TYPE_I32 &&
op->type == GGML_TYPE_I64;
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
return has_simdgroup_reduction; return has_simdgroup_reduction;
case GGML_OP_NORM: case GGML_OP_NORM:

View File

@ -78,7 +78,6 @@
#define FC_MUL_MM 700 #define FC_MUL_MM 700
#define FC_ROPE 800 #define FC_ROPE 800
#define FC_SSM_CONV 900 #define FC_SSM_CONV 900
#define FC_COUNT_EQUAL 1000
// op-specific constants // op-specific constants
#define OP_FLASH_ATTN_EXT_NQPTG 8 #define OP_FLASH_ATTN_EXT_NQPTG 8
@ -895,25 +894,6 @@ typedef struct {
float step; float step;
} ggml_metal_kargs_arange; } ggml_metal_kargs_arange;
typedef struct {
int64_t val;
} ggml_metal_kargs_memset;
typedef struct {
int32_t ne00;
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb00;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
uint64_t nb10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
} ggml_metal_kargs_count_equal;
typedef struct { typedef struct {
int32_t k0; int32_t k0;
int32_t k1; int32_t k1;

View File

@ -448,11 +448,7 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
{ {
n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
} break; } break;
case GGML_OP_COUNT_EQUAL: default:
{
n_fuse = ggml_metal_op_count_equal(ctx, idx);
} break;
default:
{ {
GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -2181,11 +2177,7 @@ size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) {
const bool has_mask = op->src[3] != nullptr; const bool has_mask = op->src[3] != nullptr;
// note: the non-vec kernel requires more extra memory, so always reserve for it if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
GGML_ASSERT(OP_FLASH_ATTN_EXT_NCPSG >= OP_FLASH_ATTN_EXT_VEC_NCPSG);
//if (ggml_metal_op_flash_attn_ext_use_vec(op)) {
if (false) {
// note: always reserve the padding space to avoid graph reallocations // note: always reserve the padding space to avoid graph reallocations
//const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; //const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0;
const bool has_kvpad = true; const bool has_kvpad = true;
@ -4098,64 +4090,3 @@ int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
return 1; return 1;
} }
int ggml_metal_op_count_equal(ggml_metal_op_t ctx, int idx) {
ggml_tensor * op = ctx->node(idx);
ggml_metal_library_t lib = ctx->lib;
ggml_metal_encoder_t enc = ctx->enc;
GGML_TENSOR_LOCALS(int32_t, ne0, op->src[0], ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb);
{
ggml_metal_kargs_memset args = { /*.val =*/ 0 };
auto pipeline = ggml_metal_library_get_pipeline_memset(lib, op);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 1);
ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1);
}
ggml_metal_op_concurrency_reset(ctx);
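// Two passes: the memset above zeroes the I64 counter, and the concurrency
// reset orders it before the atomic accumulation pass below.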
{
ggml_metal_kargs_count_equal args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
/*.ne02 =*/ ne02,
/*.ne03 =*/ ne03,
/*.nb00 =*/ nb00,
/*.nb01 =*/ nb01,
/*.nb02 =*/ nb02,
/*.nb03 =*/ nb03,
/*.nb10 =*/ nb10,
/*.nb11 =*/ nb11,
/*.nb12 =*/ nb12,
/*.nb13 =*/ nb13,
};
auto pipeline = ggml_metal_library_get_pipeline_count_equal(lib, op);
const size_t smem = pipeline.smem;
const int nth = 32*pipeline.nsg;
GGML_ASSERT(nth <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2);
ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3);
ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0);
ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, 1, 1);
}
return 1;
}

View File

@ -87,7 +87,6 @@ int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx); int ggml_metal_op_tri (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx);
int ggml_metal_op_count_equal (ggml_metal_op_t ctx, int idx);
#ifdef __cplusplus #ifdef __cplusplus
} }

View File

@ -1790,7 +1790,6 @@ kernel void kernel_op_sum_f32(
return; return;
} }
// TODO: become function constant
const uint nsg = (ntg.x + 31) / 32; const uint nsg = (ntg.x + 31) / 32;
float sumf = 0; float sumf = 0;
@ -9558,6 +9557,9 @@ template [[host_name("kernel_mul_mm_iq4_xs_f32")]] kernel mul_mm_t kernel_mul_m
template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_f32_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_f16_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_bf16_f16")]] kernel mul_mm_t kernel_mul_mm<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q4_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q4_1_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_q5_0_f16")]] kernel mul_mm_t kernel_mul_mm<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@ -9613,6 +9615,9 @@ template [[host_name("kernel_mul_mm_id_iq4_xs_f32")]] kernel mul_mm_id kernel_m
template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_f32_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, float4x4, 1, dequantize_f32, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_f16_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, half4x4, 1, dequantize_f16, half, half4x4, half, half2x4>;
#if defined(GGML_METAL_HAS_BF16)
template [[host_name("kernel_mul_mm_id_bf16_f16")]] kernel mul_mm_id kernel_mul_mm_id<bfloat, bfloat4x4, simdgroup_bfloat8x8, half, half2x4, simdgroup_half8x8, bfloat4x4, 1, dequantize_bf16, bfloat, bfloat4x4, half, half2x4>;
#endif
template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_0, 2, dequantize_q4_0, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q4_1_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q4_1, 2, dequantize_q4_1, float, float4x4, half, half2x4>;
template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>; template [[host_name("kernel_mul_mm_id_q5_0_f16")]] kernel mul_mm_id kernel_mul_mm_id<half, half4x4, simdgroup_half8x8, half, half2x4, simdgroup_half8x8, block_q5_0, 2, dequantize_q5_0, float, float4x4, half, half2x4>;
@ -9915,75 +9920,3 @@ kernel void kernel_opt_step_sgd_f32(
x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
} }
template<typename T>
kernel void kernel_memset(
constant ggml_metal_kargs_fill & args,
device T * dst,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = args.val;
}
typedef decltype(kernel_memset<int64_t>) kernel_memset_t;
template [[host_name("kernel_memset_i64")]] kernel kernel_memset_t kernel_memset<int64_t>;
constant short FC_count_equal_nsg [[function_constant(FC_COUNT_EQUAL + 0)]];
template<typename T>
kernel void kernel_count_equal(
constant ggml_metal_kargs_count_equal & args,
device const char * src0,
device const char * src1,
device atomic_int * dst,
threadgroup int32_t * shmem_i32 [[threadgroup(0)]],
uint3 tgpig[[threadgroup_position_in_grid]],
ushort3 tpitg[[thread_position_in_threadgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]],
ushort3 ntg[[threads_per_threadgroup]]) {
const short NSG = FC_count_equal_nsg;
const int i3 = tgpig.z;
const int i2 = tgpig.y;
const int i1 = tgpig.x;
if (i3 >= args.ne03 || i2 >= args.ne02 || i1 >= args.ne01) {
return;
}
int sum = 0;
device const char * base0 = src0 + i1*args.nb01 + i2*args.nb02 + i3*args.nb03;
device const char * base1 = src1 + i1*args.nb11 + i2*args.nb12 + i3*args.nb13;
for (int64_t i0 = tpitg.x; i0 < args.ne00; i0 += ntg.x) {
const T v0 = *(device const T *)(base0 + i0*args.nb00);
const T v1 = *(device const T *)(base1 + i0*args.nb10);
sum += (v0 == v1);
}
sum = simd_sum(sum);
if (tiisg == 0) {
shmem_i32[sgitg] = sum;
}
threadgroup_barrier(mem_flags::mem_threadgroup);
if (sgitg == 0) {
float v = 0.0f;
if (tpitg.x < NSG) {
v = shmem_i32[tpitg.x];
}
float total = simd_sum(v);
if (tpitg.x == 0) {
atomic_fetch_add_explicit(dst, (int32_t) total, memory_order_relaxed);
}
}
}
typedef decltype(kernel_count_equal<int32_t>) kernel_count_equal_t;
template [[host_name("kernel_count_equal_i32")]] kernel kernel_count_equal_t kernel_count_equal<int32_t>;

View File

@ -1517,12 +1517,10 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false); struct ggml_cgraph * graph = ggml_new_graph_custom(ctx, n_nodes, false);
graph->n_nodes = n_nodes; graph->n_nodes = n_nodes;
std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs; std::unordered_map<uint64_t, const rpc_tensor*> tensor_ptrs;
tensor_ptrs.reserve(n_tensors);
for (uint32_t i = 0; i < n_tensors; i++) { for (uint32_t i = 0; i < n_tensors; i++) {
tensor_ptrs.emplace(tensors[i].id, &tensors[i]);
tensor_ptrs[tensors[i].id] = &tensors[i];
} }
std::unordered_map<uint64_t, ggml_tensor*> tensor_map; std::unordered_map<uint64_t, ggml_tensor*> tensor_map;
tensor_map.reserve(n_nodes);
for (uint32_t i = 0; i < n_nodes; i++) { for (uint32_t i = 0; i < n_nodes; i++) {
int64_t id; int64_t id;
memcpy(&id, &nodes[i], sizeof(id)); memcpy(&id, &nodes[i], sizeof(id));
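
The removed lines on the left also show the standard bulk-fill pattern for std::unordered_map: reserve once so no rehash happens mid-loop, and emplace to build each entry in place rather than default-constructing through operator[]. A self-contained illustration with a hypothetical element type:

    #include <cstdint>
    #include <unordered_map>
    #include <vector>

    struct rpc_tensor_like { uint64_t id; /* payload elided */ };

    static std::unordered_map<uint64_t, const rpc_tensor_like *>
    index_by_id(const std::vector<rpc_tensor_like> & tensors) {
        std::unordered_map<uint64_t, const rpc_tensor_like *> ptrs;
        ptrs.reserve(tensors.size());   // one table allocation, no rehash in the loop
        for (const auto & t : tensors) {
            ptrs.emplace(t.id, &t);     // operator[] would value-initialize first
        }
        return ptrs;
    }
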

View File

@ -231,4 +231,3 @@ if (GGML_SYCL_DEVICE_ARCH)
target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) target_compile_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH}) target_link_options(ggml-sycl PRIVATE -Xsycl-target-backend --offload-arch=${GGML_SYCL_DEVICE_ARCH})
endif() endif()

View File

@ -434,15 +434,8 @@ static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax_norm{ GGM
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV, GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE }; GGML_OP_RESHAPE };
static constexpr std::initializer_list<ggml_op> topk_moe_sigmoid_norm_bias{ GGML_OP_UNARY, GGML_OP_RESHAPE, GGML_OP_ADD,
GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE };
static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT, static constexpr std::initializer_list<ggml_op> topk_moe_early_softmax { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS }; GGML_OP_VIEW, GGML_OP_GET_ROWS };
static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW, static constexpr std::initializer_list<ggml_op> topk_moe_late_softmax { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE }; GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
@ -471,32 +464,6 @@ static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softma
{ 9, 0, 8 }, // reshape->src[0] == div { 9, 0, 8 }, // reshape->src[0] == div
}; };
//node #436 ( UNARY): ffn_moe_probs-10 ( 256K) [Vulka ] use=2: ffn_moe_logits-10 ( 256K) [Vulka ]
//node #437 ( RESHAPE): ffn_moe_probs-10 (re ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ]
//node #438 ( ADD): ffn_moe_probs_biased ( 256K) [Vulka ] use=1: ffn_moe_probs-10 ( 256K) [Vulka ] blk.10.exp_probs_b.b ( 0K) [Vulka ]
//node #439 ( ARGSORT): ffn_moe_argsort-10 ( 256K) [Vulka ] use=1: ffn_moe_probs_biased ( 256K) [Vulka ]
//node #440 ( VIEW): ffn_moe_topk-10 ( 255K) [Vulka ] use=3: ffn_moe_argsort-10 ( 256K) [Vulka ]
//node #441 ( GET_ROWS): ffn_moe_weights-10 ( 12K) [Vulka ] use=1: ffn_moe_probs-10 (re ( 256K) [Vulka ] ffn_moe_topk-10 ( 255K) [Vulka ]
//node #442 ( RESHAPE): ffn_moe_weights-10 ( ( 12K) [Vulka ] use=2: ffn_moe_weights-10 ( 12K) [Vulka ]
//node #443 ( SUM_ROWS): ffn_moe_weights_sum- ( 2K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ]
//node #444 ( CLAMP): ffn_moe_weights_sum_ ( 2K) [Vulka ] use=1: ffn_moe_weights_sum- ( 2K) [Vulka ]
//node #445 ( DIV): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights-10 ( ( 12K) [Vulka ] ffn_moe_weights_sum_ ( 2K) [Vulka ]
//node #446 ( RESHAPE): ffn_moe_weights_norm ( 12K) [Vulka ] use=1: ffn_moe_weights_norm ( 12K) [Vulka ]
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_sigmoid_norm_bias_edges {
{ 1, 0, 0 }, // reshape->src[0] == sigmoid
{ 2, 0, 0 }, // add->src[0] == sigmoid
{ 3, 0, 2 }, // argsort->src[0] == add
{ 4, 0, 3 }, // view->src[0] == argsort
{ 5, 0, 1 }, // get_rows->src[0] == reshape
{ 5, 1, 4 }, // get_rows->src[1] == view
{ 6, 0, 5 }, // reshape->src[0] == get_rows
{ 7, 0, 6 }, // sum_rows->src[0] == reshape
{ 8, 0, 7 }, // clamp->src[0] == sum_rows
{ 9, 0, 6 }, // div->src[0] == reshape
{ 9, 1, 8 }, // div->src[1] == clamp
{10, 0, 9 }, // reshape->src[0] == div
};
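
Each { node, src_slot, src_node } triple above pins one producer/consumer edge of the candidate subgraph, with indices relative to the first fused node; the node dump in the comments is the concrete graph the sigmoid pattern encodes. A minimal sketch of checking such a table, on a simplified stand-in for ggml_tensor (the real check is ggml_check_edges over a ggml_cgraph):

    #include <array>
    #include <initializer_list>

    struct node_like {
        const node_like * src[4] = {};   // simplified stand-in for ggml_tensor::src
    };

    // True if, for every edge {n, s, m}, nodes[base + n]->src[s] == nodes[base + m].
    static bool check_edges(const node_like * const * nodes, int base,
                            std::initializer_list<std::array<int, 3>> edges) {
        for (const auto & e : edges) {
            if (nodes[base + e[0]]->src[e[1]] != nodes[base + e[2]]) {
                return false;
            }
        }
        return true;
    }
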
// same as early_softmax_norm but ending after the get_rows // same as early_softmax_norm but ending after the get_rows
static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges { static constexpr std::initializer_list<std::array<int, 3>> topk_moe_early_softmax_edges {
{ 1, 0, 0 }, // reshape->src[0] == softmax { 1, 0, 0 }, // reshape->src[0] == softmax
@ -524,10 +491,16 @@ enum topk_moe_mode {
TOPK_MOE_EARLY_SOFTMAX, TOPK_MOE_EARLY_SOFTMAX,
TOPK_MOE_EARLY_SOFTMAX_NORM, TOPK_MOE_EARLY_SOFTMAX_NORM,
TOPK_MOE_LATE_SOFTMAX, TOPK_MOE_LATE_SOFTMAX,
TOPK_MOE_SIGMOID_NORM_BIAS,
TOPK_MOE_COUNT, TOPK_MOE_COUNT,
}; };
static topk_moe_mode ggml_vk_num_additional_ops_to_topk_moe_mode(uint32_t num) {
topk_moe_mode mode = num == topk_moe_early_softmax_norm.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX_NORM :
num == topk_moe_early_softmax.size() - 1 ? TOPK_MOE_EARLY_SOFTMAX :
TOPK_MOE_LATE_SOFTMAX;
return mode;
}
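
This helper is sound only because the three fused-op lists above have pairwise-distinct lengths (9, 4, and 5 additional ops — the same offsets that reappear as node_idx + 9 / + 4 / + 5 when ggml_vk_topk_moe picks the weights node). Stated as a compile-time invariant, with the counts copied from the lists:

    // additional fused ops = pattern length minus the anchor node, per the lists above
    constexpr int N_EARLY_SOFTMAX_NORM = 10 - 1;
    constexpr int N_EARLY_SOFTMAX      = 5 - 1;
    constexpr int N_LATE_SOFTMAX       = 6 - 1;

    static_assert(N_EARLY_SOFTMAX_NORM != N_EARLY_SOFTMAX &&
                  N_EARLY_SOFTMAX_NORM != N_LATE_SOFTMAX &&
                  N_EARLY_SOFTMAX      != N_LATE_SOFTMAX,
                  "the fused-op count alone must identify the topk_moe mode");
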
static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges { static constexpr std::initializer_list<std::array<int, 3>> rope_view_set_rows_edges {
{ 1, 0, 0 }, // view->src[0] == rope { 1, 0, 0 }, // view->src[0] == rope
{ 2, 0, 1 }, // set_rows->src[0] == view { 2, 0, 1 }, // set_rows->src[0] == view
@ -550,8 +523,6 @@ struct vk_device_struct {
uint64_t max_memory_allocation_size; uint64_t max_memory_allocation_size;
uint64_t max_buffer_size; uint64_t max_buffer_size;
uint64_t suballocation_block_size; uint64_t suballocation_block_size;
uint64_t min_imported_host_pointer_alignment;
bool external_memory_host {};
bool fp16; bool fp16;
bool bf16; bool bf16;
bool pipeline_robustness; bool pipeline_robustness;
@ -767,9 +738,6 @@ struct vk_device_struct {
vk_pipeline pipeline_topk_f32[num_topk_pipelines]; vk_pipeline pipeline_topk_f32[num_topk_pipelines];
vk_pipeline pipeline_sum_rows_f32; vk_pipeline pipeline_sum_rows_f32;
vk_pipeline pipeline_cumsum_f32; vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_cumsum_small_f32;
vk_pipeline pipeline_cumsum_multipass1_f32;
vk_pipeline pipeline_cumsum_multipass2_f32;
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32; std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
@ -798,7 +766,7 @@ struct vk_device_struct {
vk_pipeline pipeline_count_experts; vk_pipeline pipeline_count_experts;
// [2] is for whether to take n_experts from spec constant (0) or push constant (1) // [2] is for whether to take n_experts from spec constant (0) or push constant (1)
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][2];
vk_pipeline pipeline_topk_moe[num_topk_moe_pipelines][TOPK_MOE_COUNT][2];
std::vector<vk_pipeline_ref> all_pipelines; std::vector<vk_pipeline_ref> all_pipelines;
@ -1213,11 +1181,6 @@ struct vk_op_topk_moe_push_constants {
uint32_t n_expert_used; uint32_t n_expert_used;
float clamp_min; float clamp_min;
float clamp_max; float clamp_max;
uint32_t gating_func;
uint32_t has_bias;
uint32_t with_norm;
float output_scale;
float output_bias;
}; };
struct vk_op_add_id_push_constants { struct vk_op_add_id_push_constants {
@ -1808,8 +1771,6 @@ struct ggml_backend_vk_context {
// Bit 'i' means nodes[start_of_fusion + i] writes to memory. // Bit 'i' means nodes[start_of_fusion + i] writes to memory.
// If there's no fusion, bit 0 is still set. // If there's no fusion, bit 0 is still set.
int fused_ops_write_mask {}; int fused_ops_write_mask {};
topk_moe_mode fused_topk_moe_mode {};
bool fused_topk_moe_scale {};
// for GGML_VK_PERF_LOGGER // for GGML_VK_PERF_LOGGER
std::unique_ptr<vk_perf_logger> perf_logger; std::unique_ptr<vk_perf_logger> perf_logger;
@ -2412,8 +2373,7 @@ static std::vector<uint32_t> ggml_vk_find_memory_properties(const vk::PhysicalDe
return indices; return indices;
} }
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list,
                                       void *import_ptr = nullptr) {
static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list<vk::MemoryPropertyFlags> & req_flags_list) {
VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")");
if (size > device->max_buffer_size) { if (size > device->max_buffer_size) {
throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit");
@ -2442,12 +2402,6 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
nullptr, nullptr,
}; };
vk::ExternalMemoryBufferCreateInfo external_memory_bci;
if (import_ptr) {
external_memory_bci.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
buffer_create_info.setPNext(&external_memory_bci);
}
buf->buffer = device->device.createBuffer(buffer_create_info); buf->buffer = device->device.createBuffer(buffer_create_info);
vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer);
@ -2462,80 +2416,35 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
mem_flags_info.setPNext(&mem_priority_info); mem_flags_info.setPNext(&mem_priority_info);
} }
    if (import_ptr) {
        vk::MemoryHostPointerPropertiesEXT host_pointer_props;
        try {
            host_pointer_props = device->device.getMemoryHostPointerPropertiesEXT(vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, import_ptr);
        } catch (vk::SystemError& e) {
            GGML_LOG_WARN("ggml_vulkan: Failed getMemoryHostPointerPropertiesEXT (%s)\n", e.what());
            device->device.destroyBuffer(buf->buffer);
            return {};
        }
        vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
        uint32_t memory_type_idx;
        vk::MemoryPropertyFlags property_flags = *req_flags_list.begin();
        for (memory_type_idx = 0; memory_type_idx < 32; ++memory_type_idx) {
            if (!(host_pointer_props.memoryTypeBits & (1u << memory_type_idx))) {
                continue;
            }
            if (!(mem_req.memoryTypeBits & (1u << memory_type_idx))) {
                continue;
            }
            vk::MemoryType memory_type = mem_props.memoryTypes[memory_type_idx];
            // check for visible+coherent+cached. Other flags (e.g. devicelocal) are allowed
            if ((memory_type.propertyFlags & property_flags) == property_flags) {
                property_flags = memory_type.propertyFlags;
                break;
            }
        }
        if (memory_type_idx == 32) {
            GGML_LOG_WARN("ggml_vulkan: Memory type for host allocation not found\n");
            device->device.destroyBuffer(buf->buffer);
            return {};
        }
        buf->memory_property_flags = mem_props.memoryTypes[memory_type_idx].propertyFlags;
        try {
            vk::ImportMemoryHostPointerInfoEXT import_info;
            import_info.handleType = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;
            import_info.pHostPointer = import_ptr;
            import_info.setPNext(&mem_flags_info);
            buf->device_memory = device->device.allocateMemory({ size, memory_type_idx, &import_info });
        } catch (const vk::SystemError& e) {
        }
    } else {
        for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
            const auto & req_flags = *it;
            const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
            if (memory_type_indices.empty()) {
                continue;
            }
            buf->memory_property_flags = req_flags;
            bool done = false;
            for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
                try {
                    buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
                    done = true;
                    break;
                } catch (const vk::SystemError& e) {
                    // loop and retry
                    // during last attempt throw the exception
                    if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
                        device->device.destroyBuffer(buf->buffer);
                        throw e;
                    }
                }
            }
            if (done) {
                break;
            }
        }
    }
    for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
        const auto & req_flags = *it;
        const std::vector<uint32_t> memory_type_indices = ggml_vk_find_memory_properties(&mem_props, &mem_req, req_flags);
        if (memory_type_indices.empty()) {
            continue;
        }
        buf->memory_property_flags = req_flags;
        bool done = false;
        for (auto mtype_it = memory_type_indices.begin(); mtype_it != memory_type_indices.end(); mtype_it++) {
            try {
                buf->device_memory = device->device.allocateMemory({ mem_req.size, *mtype_it, &mem_flags_info });
                done = true;
                break;
            } catch (const vk::SystemError& e) {
                // loop and retry
                // during last attempt throw the exception
                if (it + 1 == req_flags_list.end() && mtype_it + 1 == memory_type_indices.end()) {
                    device->device.destroyBuffer(buf->buffer);
                    throw e;
                }
            }
        }
        if (done) {
            break;
        }
    }
@ -2546,12 +2455,8 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
buf->ptr = nullptr; buf->ptr = nullptr;
    if (import_ptr) {
        buf->ptr = import_ptr;
    } else {
        if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
            buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
        }
    }
    if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) {
        buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE);
    }
device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0);
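
The removed path is essentially the canonical VK_EXT_external_memory_host flow: query which memory types can back the host pointer, intersect with the buffer's requirements, then chain a vk::ImportMemoryHostPointerInfoEXT into the allocation. A condensed sketch (Vulkan-Hpp; error handling, cleanup, and dispatch-loader setup elided; assumes ptr and size satisfy minImportedHostPointerAlignment and the extension is enabled):

    // Sketch only: import a host allocation as the backing store of a buffer.
    vk::Buffer import_host_buffer(vk::Device device, void * ptr, vk::DeviceSize size) {
        vk::ExternalMemoryBufferCreateInfo ext_info;
        ext_info.handleTypes = vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT;

        vk::BufferCreateInfo buf_info({}, size, vk::BufferUsageFlagBits::eStorageBuffer);
        buf_info.setPNext(&ext_info);
        vk::Buffer buffer = device.createBuffer(buf_info);

        // which memory types may back this pointer, and which does the buffer accept?
        vk::MemoryHostPointerPropertiesEXT props = device.getMemoryHostPointerPropertiesEXT(
            vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, ptr);
        vk::MemoryRequirements req = device.getBufferMemoryRequirements(buffer);
        uint32_t idx = 0;
        while (idx < 32 && !((props.memoryTypeBits & req.memoryTypeBits) & (1u << idx))) {
            ++idx;   // first type acceptable to both (real code also checks property flags)
        }

        vk::ImportMemoryHostPointerInfoEXT import_info(
            vk::ExternalMemoryHandleTypeFlagBits::eHostAllocationEXT, ptr);
        vk::MemoryAllocateInfo alloc(size, idx);
        alloc.setPNext(&import_info);
        vk::DeviceMemory mem = device.allocateMemory(alloc);

        device.bindBufferMemory(buffer, mem, 0);
        return buffer;   // mem intentionally not tracked in this sketch
    }
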
@ -2763,7 +2668,7 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
switch (src0_type) { switch (src0_type) {
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_M:
lut_size = 2*2048 + 4*2048;
lut_size = 2*2048;
break; break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
lut_size = 8*256; lut_size = 8*256;
@ -2956,50 +2861,44 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1;
const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1; const uint32_t tk_s = device->coopmat_support ? device->coopmat_k : 1;
    const uint32_t s_warptile_wm = device->subgroup_size == 8 ? 8 : 32;

    l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile = { subgroup_size_32, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile_mmq = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    // Integer MMQ has a smaller shared memory profile, but heavier register use
    l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, subgroup_size_8 };

    // K-quants use even more registers, mitigate by setting WMITER to 1
    l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, subgroup_size_8 };

    l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
    m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
    s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };

    l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
    m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
    s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
    m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
    s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
    m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
    s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, s_warptile_wm, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };

    l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 };
    m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
    s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 };

    // Integer MMQ has a smaller shared memory profile, but heavier register use
    l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 };

    // K-quants use even more registers, mitigate by setting WMITER to 1
    l_warptile_mmq_int_k = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 1, 4, 4, 1, subgroup_size_8 };
    m_warptile_mmq_int_k = { 128, 64, 64, 32, subgroup_size_8, 32, 1, 2, 2, 1, subgroup_size_8 };
    s_warptile_mmq_int_k = { subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, subgroup_size_8 };

    l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 };
    m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 };
    s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 };

    l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 };
    m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 };
    s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, 4, 4, 1, mul_mat_subgroup_size_8 };
    m_warptile_mmqid_int = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, 2, 2, 1, mul_mat_subgroup_size_8 };
    s_warptile_mmqid_int = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, mul_mat_subgroup_size_8 };

    l_warptile_mmqid_int_k = { 128, 128, 128, 32, mul_mat_subgroup_size_16 * 2, 64, 1, 4, 4, 1, mul_mat_subgroup_size_16 };
    m_warptile_mmqid_int_k = { 128, 64, 64, 32, mul_mat_subgroup_size_16, 32, 1, 2, 2, 1, mul_mat_subgroup_size_16 };
    s_warptile_mmqid_int_k = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 1, 2, 1, 1, mul_mat_subgroup_size_16 };
// chip specific tuning // chip specific tuning
if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) {
m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; m_warptile_mmqid = m_warptile_mmqid_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 };
} else if (device->vendor_id == VK_VENDOR_ID_INTEL && device->coopmat_support && device->architecture == INTEL_XE2) {
// Xe2/Xe3 with coopmat enabled - warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
l_warptile_mmq = { 512, 128, 128, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 };
} }
l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 };
@ -3682,11 +3581,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
m_wg_denoms = { 64, 64, 1 }; m_wg_denoms = { 64, 64, 1 };
s_wg_denoms = { 32, 32, 1 }; s_wg_denoms = { 32, 32, 1 };
if (device->vendor_id == VK_VENDOR_ID_INTEL && device->architecture == INTEL_XE2) {
// Xe2/Xe3 - bf16 warptile performance tuning
l_warptile = { 512, 128, 128, 16, subgroup_size_8, 32, 2, 4, 4, 1, subgroup_size_8 };
}
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0); CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, mul_mat_id_param_count, _id, 0);
} }
@ -3699,7 +3593,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
uint32_t rm_kq = 2; uint32_t rm_kq = 2;
uint32_t rm_stdq_int = 1; uint32_t rm_stdq_int = 1;
uint32_t rm_kq_int = 1; uint32_t rm_kq_int = 1;
auto const &rm_iq_int = [](uint32_t i) { return i == 0 ? 8u : 4u; };
if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->architecture == AMD_GCN) { if (device->architecture == AMD_GCN) {
rm_stdq = 2; rm_stdq = 2;
@ -3803,10 +3696,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_q8_1_f32", arr_dmmv_iq1_s_q8_1_f32_len[reduc], arr_dmmv_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_q8_1_f32", arr_dmmv_iq1_m_q8_1_f32_len[reduc], arr_dmmv_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(i), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(i), i+1}, 1, true, use_subgroups, subgroup_size_int);
} }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
@ -3853,9 +3742,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_q8_1_f32", arr_dmmv_id_iq1_s_q8_1_f32_len[reduc], arr_dmmv_id_iq1_s_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_q8_1_f32", arr_dmmv_id_iq1_m_q8_1_f32_len[reduc], arr_dmmv_id_iq1_m_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_iq_int(0), 1, 1}, {wg_size_subgroup_int, 1*rm_iq_int(0)}, 1, true, use_subgroups, subgroup_size_int);
} }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
@ -3863,7 +3749,6 @@ static void ggml_vk_load_shaders(vk_device& device) {
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
GGML_UNUSED(rm_stdq_int); GGML_UNUSED(rm_stdq_int);
GGML_UNUSED(rm_kq_int); GGML_UNUSED(rm_kq_int);
GGML_UNUSED(rm_iq_int);
#endif #endif
// dequant shaders // dequant shaders
@ -4250,11 +4135,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
const uint32_t cumsum_elem_per_thread = (device->vendor_id == VK_VENDOR_ID_AMD || device->vendor_id == VK_VENDOR_ID_INTEL) ? 2 : 4;
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 256, device->subgroup_size, cumsum_elem_per_thread }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_small_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size, 1 }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass1_f32, "cumsum_multipass1_f32", cumsum_multipass1_f32_len, cumsum_multipass1_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_multipass2_f32, "cumsum_multipass2_f32", cumsum_multipass2_f32_len, cumsum_multipass2_f32_data, "main", 3, sizeof(vk_op_sum_rows_push_constants), {256, 1, 1}, { 256, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_cumsum_f32, "cumsum_f32", cumsum_f32_len, cumsum_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { 128, device->subgroup_size }, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
@ -4410,7 +4291,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
for (uint32_t use_push = 0; use_push < 2; ++use_push) { for (uint32_t use_push = 0; use_push < 2; ++use_push) {
for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) { for (uint32_t i = 0; i < num_topk_moe_pipelines; ++i) {
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][use_push], "topk_moe_f32_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 4, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX][use_push], "topk_moe_f32_early_softmax_"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_EARLY_SOFTMAX_NORM][use_push], "topk_moe_f32_early_softmax_norm"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 1, 0, use_push}, 1, true, true, device->subgroup_size);
ggml_vk_create_pipeline2(device, device->pipeline_topk_moe[i][TOPK_MOE_LATE_SOFTMAX][use_push], "topk_moe_f32_late_softmax"+std::to_string(i), topk_moe_f32_len, topk_moe_f32_data, "main", 3, sizeof(vk_op_topk_moe_push_constants), {1, 1, 1}, {device->subgroup_size, 1u<<i, 0, 1, use_push}, 1, true, true, device->subgroup_size);
} }
} }
@ -4514,8 +4397,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
} else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 && } else if (strcmp("VK_EXT_memory_priority", properties.extensionName) == 0 &&
getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) { getenv("GGML_VK_ENABLE_MEMORY_PRIORITY")) {
device->memory_priority = true; device->memory_priority = true;
} else if (strcmp("VK_EXT_external_memory_host", properties.extensionName) == 0) {
device->external_memory_host = true;
} }
} }
@ -4530,7 +4411,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
vk::PhysicalDeviceVulkan12Properties vk12_props; vk::PhysicalDeviceVulkan12Properties vk12_props;
vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props;
vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props;
vk::PhysicalDeviceExternalMemoryHostPropertiesEXT external_memory_host_props;
props2.pNext = &props3; props2.pNext = &props3;
props3.pNext = &subgroup_props; props3.pNext = &subgroup_props;
@ -4570,22 +4450,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props;
} }
if (device->external_memory_host) {
last_struct->pNext = (VkBaseOutStructure *)&external_memory_host_props;
last_struct = (VkBaseOutStructure *)&external_memory_host_props;
}
device->physical_device.getProperties2(&props2); device->physical_device.getProperties2(&props2);
device->properties = props2.properties; device->properties = props2.properties;
device->vendor_id = device->properties.vendorID; device->vendor_id = device->properties.vendorID;
device->driver_id = driver_props.driverID; device->driver_id = driver_props.driverID;
if (device->driver_id == vk::DriverId::eMoltenvk) {
// Disable external_memory_host until https://github.com/KhronosGroup/MoltenVK/pull/2622
// is available in the Vulkan SDK.
device->external_memory_host = false;
}
// Implementing the async backend interfaces seems broken on older Intel HW, // Implementing the async backend interfaces seems broken on older Intel HW,
// see https://github.com/ggml-org/llama.cpp/issues/17302. // see https://github.com/ggml-org/llama.cpp/issues/17302.
device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL || device->support_async = (device->vendor_id != VK_VENDOR_ID_INTEL ||
@ -4667,8 +4536,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated;
device->min_imported_host_pointer_alignment = external_memory_host_props.minImportedHostPointerAlignment;
device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations))); device->max_workgroup_size_log2 = uint32_t(log2f(float(device->properties.limits.maxComputeWorkGroupInvocations)));
std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties(); std::vector<vk::QueueFamilyProperties> queue_family_props = device->physical_device.getQueueFamilyProperties();
@ -4800,10 +4667,6 @@ static vk_device ggml_vk_get_device(size_t idx) {
device_extensions.push_back("VK_KHR_pipeline_executable_properties"); device_extensions.push_back("VK_KHR_pipeline_executable_properties");
} }
if (device->external_memory_host) {
device_extensions.push_back("VK_EXT_external_memory_host");
}
vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2);
device->pipeline_executable_properties_support = pipeline_executable_properties_support; device->pipeline_executable_properties_support = pipeline_executable_properties_support;
@ -5070,23 +4933,11 @@ static vk_device ggml_vk_get_device(size_t idx) {
switch (device->vendor_id) { switch (device->vendor_id) {
#ifndef GGML_VULKAN_RUN_TESTS #ifndef GGML_VULKAN_RUN_TESTS
            case VK_VENDOR_ID_AMD:
                device->mul_mat_l[i] = false;
                device->mul_mat_m[i] = true;
                device->mul_mat_s[i] = true;
                device->mul_mat_id_l[i] = false;
                device->mul_mat_id_m[i] = true;
                device->mul_mat_id_s[i] = true;
                break;
            case VK_VENDOR_ID_INTEL:
                if (!device->coopmat_support || device->architecture != INTEL_XE2) {
                    device->mul_mat_l[i] = false;
                    device->mul_mat_id_l[i] = false;
                } else {
                    device->mul_mat_l[i] = true; // if coopmat & XE2+, allow large matmul warptile config for Intel
                    device->mul_mat_id_l[i] = true;
                }
                device->mul_mat_m[i] = true;
                device->mul_mat_s[i] = true;
                device->mul_mat_id_m[i] = true;
                device->mul_mat_id_s[i] = true;
                break;
            case VK_VENDOR_ID_AMD:
            case VK_VENDOR_ID_INTEL:
                device->mul_mat_l[i] = false;
                device->mul_mat_m[i] = true;
                device->mul_mat_s[i] = true;
                device->mul_mat_id_l[i] = false;
                device->mul_mat_id_m[i] = true;
                device->mul_mat_id_s[i] = true;
                break;
@ -5733,8 +5584,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
break; break;
default: default:
return nullptr; return nullptr;
@ -5891,8 +5740,6 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
case GGML_TYPE_IQ1_S:
case GGML_TYPE_IQ1_M:
break; break;
default: default:
return nullptr; return nullptr;
@ -6874,12 +6721,7 @@ static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& sub
vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1); vk_pipeline pipeline = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
    const uint32_t num_blocks = CEIL_DIV(ne, pipeline->wg_denoms[0]);
    // clamp the number of elements to the max workgroup count. The shader will iterate over the total number of blocks.
    const uint64_t max_elements = std::min<uint64_t>(uint64_t{ctx->device->properties.limits.maxComputeWorkGroupCount[0]} * pipeline->wg_denoms[0], std::numeric_limits<uint32_t>::max());
    const uint32_t elements = std::min(ne, static_cast<uint32_t>(max_elements));
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 2>{ ne, num_blocks }, { elements, 1, 1 });
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array<uint32_t, 1>{ne}, { ne, 1, 1 });
ggml_vk_sync_buffers(ctx, subctx); ggml_vk_sync_buffers(ctx, subctx);
} }
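
The removed host code caps the dispatch at maxComputeWorkGroupCount[0] * wg_denoms[0] and passes the true block count alongside, so the shader can keep working when the grid was clamped. One plausible CPU model of that arithmetic plus a grid-stride loop (illustrative names; the shader-side loop itself is not shown in this diff):

    #include <algorithm>
    #include <cstdint>
    #include <limits>

    // block_size plays the role of pipeline->wg_denoms[0]; max_groups the role of
    // properties.limits.maxComputeWorkGroupCount[0].
    static void process_all_blocks(uint32_t ne, uint32_t block_size, uint64_t max_groups) {
        const uint32_t num_blocks   = (ne + block_size - 1) / block_size;   // CEIL_DIV
        const uint64_t max_elements = std::min<uint64_t>(max_groups * block_size,
                                                         std::numeric_limits<uint32_t>::max());
        const uint32_t elements     = (uint32_t) std::min<uint64_t>(ne, max_elements);
        const uint32_t groups       = (elements + block_size - 1) / block_size;

        for (uint32_t wg = 0; wg < groups; ++wg) {            // launched workgroups
            for (uint32_t b = wg; b < num_blocks; b += groups) {
                // quantize block b: elements [b*block_size, min(ne, (b+1)*block_size))
            }
        }
    }
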
@ -7163,7 +7005,7 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
// Quantization overhead is not worth it for small k // Quantization overhead is not worth it for small k
switch (device->vendor_id) { switch (device->vendor_id) {
case VK_VENDOR_ID_NVIDIA: case VK_VENDOR_ID_NVIDIA:
if (src0_type == GGML_TYPE_Q2_K || src0_type == GGML_TYPE_IQ1_S || src0_type == GGML_TYPE_IQ1_M) {
if (src0_type == GGML_TYPE_Q2_K) {
return true; return true;
} }
@ -8842,9 +8684,10 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
if (ctx->num_additional_fused_ops) { if (ctx->num_additional_fused_ops) {
uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0])));
GGML_ASSERT(idx < num_topk_moe_pipelines); GGML_ASSERT(idx < num_topk_moe_pipelines);
topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
// use n_experts from push constant if it's not equal to the power of two spec constant // use n_experts from push constant if it's not equal to the power of two spec constant
bool use_push = dst->ne[0] != (1u << idx); bool use_push = dst->ne[0] != (1u << idx);
return ctx->device->pipeline_topk_moe[idx][use_push];
return ctx->device->pipeline_topk_moe[idx][mode][use_push];
} }
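
The pipeline index here is the ceiling log2 of the expert count: each pipeline is specialized (via spec constant) for the next power of two, and when ne[0] is not itself a power of two the real expert count travels in the push constant instead. A hypothetical standalone helper showing the arithmetic:

    #include <cmath>
    #include <cstdint>

    struct topk_moe_sel { uint32_t idx; bool use_push; };

    static topk_moe_sel select_topk_moe_pipeline(uint32_t n_experts) {
        const uint32_t idx = (uint32_t) ceilf(log2f((float) n_experts));
        // the spec-constant variant handles exactly 1<<idx experts
        return { idx, n_experts != (1u << idx) };
    }

    // select_topk_moe_pipeline(8) -> { 3, false }: spec constant 1<<3 == 8 matches.
    // select_topk_moe_pipeline(6) -> { 3, true  }: rounded up to 8, real count pushed.
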
if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && dst->type == GGML_TYPE_F32) {
@ -8917,11 +8760,7 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return nullptr; return nullptr;
case GGML_OP_CUMSUM: case GGML_OP_CUMSUM:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
    if (src0->ne[0] <= 512) {
        return ctx->device->pipeline_cumsum_small_f32;
    } else {
        return ctx->device->pipeline_cumsum_f32;
    }
    return ctx->device->pipeline_cumsum_f32;
} }
return nullptr; return nullptr;
case GGML_OP_SOLVE_TRI: case GGML_OP_SOLVE_TRI:
@ -10507,16 +10346,14 @@ static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& sub
} }
static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) { static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx) {
    topk_moe_mode mode = ctx->fused_topk_moe_mode;
    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
    ggml_tensor * bias = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 2]->src[1] : logits;
    ggml_tensor * weights = cgraph->nodes[node_idx + ctx->num_additional_fused_ops];
    ggml_tensor * ids = (mode == TOPK_MOE_SIGMOID_NORM_BIAS) ? cgraph->nodes[node_idx + 4] :
                        (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] :
                        cgraph->nodes[node_idx + 3];
    topk_moe_mode mode = ggml_vk_num_additional_ops_to_topk_moe_mode(ctx->num_additional_fused_ops);
    ggml_tensor * logits = cgraph->nodes[node_idx + 0]->src[0];
    ggml_tensor * weights = (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) ? cgraph->nodes[node_idx + 9] :
                            (mode == TOPK_MOE_EARLY_SOFTMAX) ? cgraph->nodes[node_idx + 4] :
                            cgraph->nodes[node_idx + 5];
    ggml_tensor * ids = (mode == TOPK_MOE_LATE_SOFTMAX) ? cgraph->nodes[node_idx + 1] : cgraph->nodes[node_idx + 3];
GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(bias->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -10531,7 +10368,6 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits); vk_subbuffer logits_buf = ggml_vk_tensor_subbuffer(ctx, logits);
vk_subbuffer bias_buf = ggml_vk_tensor_subbuffer(ctx, bias);
vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights); vk_subbuffer weights_buf = ggml_vk_tensor_subbuffer(ctx, weights);
vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids); vk_subbuffer ids_buf = ggml_vk_tensor_subbuffer(ctx, ids);
@ -10539,45 +10375,18 @@ static void ggml_vk_topk_moe(ggml_backend_vk_context * ctx, vk_context& subctx,
pc.n_rows = n_rows; pc.n_rows = n_rows;
pc.n_experts_push = n_experts; pc.n_experts_push = n_experts;
pc.n_expert_used = n_expert_used; pc.n_expert_used = n_expert_used;
pc.clamp_min = -std::numeric_limits<float>::infinity();
pc.clamp_max = std::numeric_limits<float>::infinity();
if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) { if (mode == TOPK_MOE_EARLY_SOFTMAX_NORM) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 7]; ggml_tensor * clamp = cgraph->nodes[node_idx + 7];
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0); pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1); pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
} }
if (mode == TOPK_MOE_SIGMOID_NORM_BIAS) {
ggml_tensor * clamp = cgraph->nodes[node_idx + 8];
GGML_ASSERT(clamp->op == GGML_OP_CLAMP);
pc.clamp_min = ggml_get_op_params_f32(clamp, 0);
pc.clamp_max = ggml_get_op_params_f32(clamp, 1);
}
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2
pc.gating_func = mode == TOPK_MOE_SIGMOID_NORM_BIAS ? GATING_FUNC_SIGMOID :
mode == TOPK_MOE_LATE_SOFTMAX ? GATING_FUNC_SOFTMAX_WEIGHT :
GATING_FUNC_SOFTMAX;
pc.has_bias = mode == TOPK_MOE_SIGMOID_NORM_BIAS;
pc.with_norm = mode == TOPK_MOE_EARLY_SOFTMAX_NORM || mode == TOPK_MOE_SIGMOID_NORM_BIAS;
if (ctx->fused_topk_moe_scale) {
GGML_ASSERT(weights->op == GGML_OP_SCALE);
pc.output_scale = ggml_get_op_params_f32(weights, 0);
pc.output_bias = ggml_get_op_params_f32(weights, 1);
} else {
pc.output_scale = 1.0f;
pc.output_bias = 0.0f;
}
GGML_ASSERT(n_expert_used <= n_experts); GGML_ASSERT(n_expert_used <= n_experts);
const uint32_t rows_per_block = 4; const uint32_t rows_per_block = 4;
std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 }; std::array<uint32_t, 3> elements = { CEIL_DIV(n_rows, rows_per_block), 1, 1 };
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, bias_buf, weights_buf, ids_buf}, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, {logits_buf, weights_buf, ids_buf}, pc, elements);
} }
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) { static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_cgraph * cgraph, int node_idx, bool backprop) {
@ -10825,50 +10634,8 @@ static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, cons
} }
static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_cumsum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
    vk_op_sum_rows_push_constants pc = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);

    // Use the single pass shader when the rows are small or there are enough rows to fill the GPU.
    // For fewer, larger rows, use the multipass shader to spread each row across SMs.
    if (dst->ne[0] <= 4096 || ggml_nrows(dst) >= ctx->device->shader_core_count) {
        ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, pc);
        return;
    }

    // First pass computes partial sums within a block, and stores the last partial
    // to the temp buffer. Second pass sums the block partials from the temp buffer
    // and adds that to the result of the first pass.
    vk_pipeline pipeline1 = ctx->device->pipeline_cumsum_multipass1_f32;
    vk_pipeline pipeline2 = ctx->device->pipeline_cumsum_multipass2_f32;
    GGML_ASSERT(pipeline1 != nullptr && pipeline2 != nullptr);

    ggml_pipeline_request_descriptor_sets(ctx, pipeline1, 1);
    ggml_pipeline_request_descriptor_sets(ctx, pipeline2, 1);

    std::array<uint32_t, 3> elements;
    elements[0] = dst->ne[0];
    elements[1] = (uint32_t)ggml_nrows(dst);
    elements[2] = 1;

    size_t temp_size = sizeof(float) * elements[0] * ggml_nrows(dst);
    if (ctx->prealloc_size_split_k < temp_size) {
        ctx->prealloc_size_split_k = temp_size;
        ggml_vk_preallocate_buffers(ctx, subctx);
    }

    vk_subbuffer src_buf = ggml_vk_tensor_subbuffer(ctx, src0);
    vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst);
    vk_subbuffer temp_buf = ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0);

    if (ctx->prealloc_split_k_need_sync) {
        ggml_vk_sync_buffers(ctx, subctx);
    }

    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline1, {src_buf, dst_buf, temp_buf}, pc, elements);
    ggml_vk_sync_buffers(ctx, subctx);
    ggml_vk_dispatch_pipeline(ctx, subctx, pipeline2, {src_buf, dst_buf, temp_buf}, pc, elements);
    ctx->prealloc_split_k_need_sync = true;
}
    vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]);
    ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_CUMSUM, p);
}
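
The removed multipass shaders are the standard two-level scan decomposition: pass 1 produces per-block inclusive scans and a buffer of block totals, pass 2 adds the running sum of the preceding block totals back onto each block. A sequential C++ model of one row (illustrative only):

    #include <algorithm>
    #include <cstddef>
    #include <vector>

    static void cumsum_two_pass(const std::vector<float> & src, std::vector<float> & dst,
                                size_t block = 256) {
        const size_t n = src.size();
        dst.resize(n);
        std::vector<float> block_totals((n + block - 1) / block);   // the "temp buffer"

        for (size_t b = 0; b * block < n; ++b) {                    // pass 1
            float run = 0.0f;
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                run += src[i];
                dst[i] = run;                                       // in-block inclusive scan
            }
            block_totals[b] = run;                                  // last partial of the block
        }
        float offset = 0.0f;
        for (size_t b = 0; b * block < n; ++b) {                    // pass 2
            for (size_t i = b * block; i < std::min(n, (b + 1) * block); ++i) {
                dst[i] += offset;                                   // add preceding block totals
            }
            offset += block_totals[b];
        }
    }
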
static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) { static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
@ -12361,11 +12128,6 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_UNARY: case GGML_OP_UNARY:
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
break;
}
switch (ggml_get_unary_op(node)) { switch (ggml_get_unary_op(node)) {
case GGML_UNARY_OP_EXP: case GGML_UNARY_OP_EXP:
case GGML_UNARY_OP_SILU: case GGML_UNARY_OP_SILU:
@ -12413,7 +12175,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else { } else {
ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node); ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node);
@ -12433,7 +12195,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
break; break;
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
if (ctx->num_additional_fused_ops) {
ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx); ggml_vk_topk_moe(ctx, compute_ctx, cgraph, node_idx);
} else { } else {
ggml_vk_argsort(ctx, compute_ctx, src0, node); ggml_vk_argsort(ctx, compute_ctx, src0, node);
@ -13286,24 +13048,6 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
get_rows = cgraph->nodes[node_idx + 4]; get_rows = cgraph->nodes[node_idx + 4];
argsort = cgraph->nodes[node_idx + 2]; argsort = cgraph->nodes[node_idx + 2];
break; break;
case TOPK_MOE_SIGMOID_NORM_BIAS:
softmax = cgraph->nodes[node_idx + 0]; // really sigmoid
weights = cgraph->nodes[node_idx + 10];
get_rows = cgraph->nodes[node_idx + 5];
argsort = cgraph->nodes[node_idx + 3];
if (ggml_get_unary_op(softmax) != GGML_UNARY_OP_SIGMOID) {
return false;
}
// bias is expected to be 1D
if (ggml_nrows(cgraph->nodes[node_idx + 2]->src[1]) != 1 ||
!ggml_is_contiguous(cgraph->nodes[node_idx + 2]->src[1])) {
return false;
}
// sigmoid fusion seems to generate infinities on moltenvk
if (ctx->device->driver_id == vk::DriverId::eMoltenvk) {
return false;
}
break;
case TOPK_MOE_EARLY_SOFTMAX:
softmax = cgraph->nodes[node_idx + 0];
weights = cgraph->nodes[node_idx + 4];
@ -13327,28 +13071,26 @@ static bool ggml_vk_can_fuse_topk_moe(ggml_backend_vk_context * ctx, const struc
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs && mode != TOPK_MOE_SIGMOID_NORM_BIAS) {
if (probs != selection_probs) {
return false;
}
const float * op_params = (const float *)softmax->op_params;
float scale = op_params[0];
float max_bias = op_params[1];
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false;
}

if (softmax->op == GGML_OP_SOFT_MAX) {
const float * op_params = (const float *)softmax->op_params;
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
float scale = op_params[0];
float max_bias = op_params[1];
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
}

const int n_expert = softmax->ne[0];
@ -13621,8 +13363,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
total_mul_mat_bytes += bytes;
}
ctx->fused_topk_moe_mode = TOPK_MOE_COUNT;
ctx->fused_topk_moe_scale = false;
const char *fusion_string {};
if (!ctx->device->disable_fusion) {
uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i);
@ -13668,23 +13408,13 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = topk_moe_early_softmax_norm.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX_NORM;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM"; fusion_string = "TOPK_MOE_EARLY_SOFTMAX_NORM";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_sigmoid_norm_bias, { i + 4, i + 10 }) &&
ggml_check_edges(cgraph, i, topk_moe_sigmoid_norm_bias_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_SIGMOID_NORM_BIAS)) {
ctx->num_additional_fused_ops = topk_moe_sigmoid_norm_bias.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 4;
ctx->fused_topk_moe_mode = TOPK_MOE_SIGMOID_NORM_BIAS;
fusion_string = "TOPK_MOE_SIGMOID_NORM_BIAS";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_early_softmax, { i + 3, i + 4 }) &&
ggml_check_edges(cgraph, i, topk_moe_early_softmax_edges) &&
ggml_vk_can_fuse_topk_moe(ctx, cgraph, i, TOPK_MOE_EARLY_SOFTMAX)) {
ctx->num_additional_fused_ops = topk_moe_early_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 3;
ctx->fused_topk_moe_mode = TOPK_MOE_EARLY_SOFTMAX;
fusion_string = "TOPK_MOE_EARLY_SOFTMAX"; fusion_string = "TOPK_MOE_EARLY_SOFTMAX";
} else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) && } else if (ggml_can_fuse_subgraph(cgraph, i, topk_moe_late_softmax, { i + 1, i + 5 }) &&
ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) && ggml_check_edges(cgraph, i, topk_moe_late_softmax_edges) &&
@ -13692,17 +13422,8 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->num_additional_fused_ops = topk_moe_late_softmax.size() - 1;
// view of argsort writes to memory
ctx->fused_ops_write_mask |= 1 << 1;
ctx->fused_topk_moe_mode = TOPK_MOE_LATE_SOFTMAX;
fusion_string = "TOPK_MOE_LATE_SOFTMAX"; fusion_string = "TOPK_MOE_LATE_SOFTMAX";
} }
if (ctx->fused_topk_moe_mode != TOPK_MOE_COUNT) {
// Look for an additional scale op to fuse - occurs in deepseek2 and nemotron3 nano.
if (ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops - 1, { GGML_OP_DIV, GGML_OP_RESHAPE, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 }) ||
ggml_can_fuse_subgraph(cgraph, i + ctx->num_additional_fused_ops, { GGML_OP_GET_ROWS, GGML_OP_SCALE }, { i + ctx->num_additional_fused_ops + 1 })) {
ctx->fused_topk_moe_scale = true;
ctx->num_additional_fused_ops++;
}
}
}
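The selection above is a linear scan over candidate op sequences. A simplified, hypothetical C++ sketch of the matching step (not the ggml API; for illustration of the idea only):

    #include <vector>

    enum Op { OP_SOFT_MAX, OP_ARGSORT, OP_VIEW, OP_GET_ROWS, OP_RESHAPE, OP_DIV, OP_SCALE };

    struct Node { Op op; };

    // Returns true if nodes[i..] matches the expected op sequence exactly.
    static bool matches_ops(const std::vector<Node> & nodes, size_t i, const std::vector<Op> & pattern) {
        if (i + pattern.size() > nodes.size()) return false;
        for (size_t k = 0; k < pattern.size(); ++k) {
            if (nodes[i + k].op != pattern[k]) return false;
        }
        return true; // a real implementation would also verify edges/operand wiring
    }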
ctx->fused_ops_write_mask |= 1 << ctx->num_additional_fused_ops;
@ -13881,9 +13602,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
if (keep_pattern(topk_moe_early_softmax_norm)) {
continue;
}
if (keep_pattern(topk_moe_sigmoid_norm_bias)) {
continue;
}
if (keep_pattern(topk_moe_early_softmax)) {
continue;
}
@ -13910,7 +13628,6 @@ static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph *
}
// Don't pull forward nodes from fusion patterns
if (match_pattern(topk_moe_early_softmax_norm, j) ||
match_pattern(topk_moe_sigmoid_norm_bias, j) ||
match_pattern(topk_moe_early_softmax, j) ||
match_pattern(topk_moe_late_softmax, j)) {
continue;
@ -14305,19 +14022,6 @@ static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const
}

static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) {
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
// reject any tensors larger than the max buffer size
for (int i = 0; i < GGML_MAX_SRC; i++) {
if (op->src[i] && ggml_nbytes(op->src[i]) > device->max_buffer_size) {
return false;
}
}
if (ggml_nbytes(op) > device->max_buffer_size) {
return false;
}
switch (op->op) {
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
@ -14366,6 +14070,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_MUL_MAT_ID:
{
ggml_type src0_type = op->src[0]->type;
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->op == GGML_OP_MUL_MAT_ID) {
if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) {
// If there's not enough shared memory for row_ids and the result tile, fallback to CPU
@ -14426,6 +14132,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
case GGML_OP_FLASH_ATTN_EXT:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
bool coopmat2 = device->coopmat2;
uint32_t HSK = op->src[1]->ne[0];
uint32_t HSV = op->src[2]->ne[0];
@ -14647,6 +14355,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// pipeline_argsort_large_f32 requires vulkan memory model.
if (device->vulkan_memory_model) {
return true;
@ -14659,6 +14369,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
// We could potentially support larger, using argsort to sort the
// whole thing. Not clear if this is needed.
uint32_t min_pipeline = (uint32_t)log2f(float(op->ne[0])) + 1;
@ -14705,6 +14417,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_CUMSUM:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
if (device->subgroup_arithmetic && device->subgroup_require_full_support) {
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]);
}
@ -14712,6 +14426,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
case GGML_OP_SOLVE_TRI:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
return false;
}
@ -14776,6 +14493,9 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return false;
}
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
const uint32_t SPLIT_H = 16;

size_t stateC_size = SPLIT_H * d_state * sizeof(float);
@ -14869,51 +14589,6 @@ static void ggml_backend_vk_device_event_synchronize(ggml_backend_dev_t dev, ggm
VK_CHECK(device->device.waitForFences({ vkev->fence }, true, UINT64_MAX), "event_synchronize");
}
static vk_buffer ggml_vk_buffer_from_host_ptr(vk_device & device, void * ptr, size_t size) {
if (!device->external_memory_host) {
return {};
}
uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
if (uptr & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
if (size & (device->min_imported_host_pointer_alignment - 1)) {
return {};
}
const vk::MemoryPropertyFlags property_flags = vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached;
vk_buffer buf {};
try {
buf = ggml_vk_create_buffer(device, size, { property_flags }, ptr);
} catch (vk::SystemError& e) {
GGML_LOG_WARN("ggml_vulkan: Failed ggml_vk_create_buffer (%s)\n", e.what());
}
return buf;
}
static ggml_backend_buffer_t ggml_backend_vk_device_buffer_from_host_ptr(ggml_backend_dev_t dev, void * ptr, size_t size, size_t max_tensor_size) {
VK_LOG_DEBUG("ggml_backend_vk_device_buffer_from_host_ptr(backend=" << dev << ", ptr=" << ptr << ", size=" << size << ")");
GGML_UNUSED(max_tensor_size);
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
auto device = ggml_vk_get_device(ctx->device);
vk_buffer buf = ggml_vk_buffer_from_host_ptr(device, ptr, size);
if (!buf) {
return {};
}
ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(device, std::move(buf), device->name);
ggml_backend_buffer_t ret = ggml_backend_buffer_init(ggml_backend_vk_device_get_buffer_type(dev), ggml_backend_vk_buffer_interface, bufctx, size);
return ret;
}
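The removed import path above relies on the usual power-of-two alignment test: x & (align - 1) is zero iff x is a multiple of align. A standalone C++ sketch of the same precondition check (assumed names, not the Vulkan API):

    #include <cstddef>
    #include <cstdint>

    // Both the pointer and the size must be multiples of the minimum
    // imported-host-pointer alignment (a power of two) to be importable.
    static bool host_ptr_importable(const void * ptr, size_t size, uintptr_t min_align) {
        const uintptr_t uptr = reinterpret_cast<uintptr_t>(ptr);
        return (uptr & (min_align - 1)) == 0 &&
               (uintptr_t(size) & (min_align - 1)) == 0;
    }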
static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .get_name = */ ggml_backend_vk_device_get_name,
/* .get_description = */ ggml_backend_vk_device_get_description,
@ -14923,7 +14598,7 @@ static const struct ggml_backend_device_i ggml_backend_vk_device_i = {
/* .init_backend = */ ggml_backend_vk_device_init,
/* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type,
/* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type,
/* .buffer_from_host_ptr = */ ggml_backend_vk_device_buffer_from_host_ptr,
/* .buffer_from_host_ptr = */ NULL,
/* .supports_op = */ ggml_backend_vk_device_supports_op,
/* .supports_buft = */ ggml_backend_vk_device_supports_buft,
/* .offload_op = */ ggml_backend_vk_device_offload_op,

View File

@ -14,7 +14,6 @@ layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
layout (constant_id = 2) const uint ELEM_PER_THREAD = 4;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
@ -39,45 +38,32 @@ void main() {
last_sum = 0;
}

uint col = tid * ELEM_PER_THREAD;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE * ELEM_PER_THREAD);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v[ELEM_PER_THREAD];
FLOAT_TYPE thread_sum = 0;
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
if (col + j < p.n_cols) {
thread_sum += FLOAT_TYPE(data_a[src_idx + col + j]);
}
v[j] = thread_sum;
}
thread_sum = subgroupExclusiveAdd(thread_sum);
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += thread_sum;
}
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v[ELEM_PER_THREAD - 1];
}
barrier();
for (int s = 0; s < subgroup_id; ++s) {
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += partial[s];
}
}
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
v[j] += last_sum;
}
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v[ELEM_PER_THREAD - 1];
}
[[unroll]] for (uint j = 0; j < ELEM_PER_THREAD; ++j) {
if (col + j < p.n_cols) {
data_d[dst_idx + col + j] = D_TYPE(v[j]);
}
}
col += BLOCK_SIZE * ELEM_PER_THREAD;
}
}

uint col = tid;
uint num_iter = CEIL_DIV(p.n_cols, BLOCK_SIZE);
for (int i = 0; i < num_iter; ++i) {
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
v += last_sum;
barrier();
if (tid == BLOCK_SIZE - 1) {
last_sum = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
col += BLOCK_SIZE;
}
}

View File

@ -1,60 +0,0 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout (binding = 2) writeonly buffer T {D_TYPE data_t[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE partial[BLOCK_SIZE / SUBGROUP_SIZE];
void main() {
const uint row = gl_WorkGroupID.y;
const uint tid = gl_LocalInvocationID.x;
const uint col = gl_GlobalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
uint subgroup_id = tid / SUBGROUP_SIZE;
FLOAT_TYPE v = 0;
if (col < p.n_cols) {
v = FLOAT_TYPE(data_a[src_idx + col]);
}
v = subgroupInclusiveAdd(v);
// Store the largest partial sum for each subgroup, then add the partials for all
// lower subgroups and the final partial sum from the previous iteration.
if (gl_SubgroupInvocationID == SUBGROUP_SIZE - 1) {
partial[subgroup_id] = v;
}
barrier();
for (int j = 0; j < subgroup_id; ++j) {
v += partial[j];
}
barrier();
if (tid == BLOCK_SIZE - 1) {
data_t[gl_WorkGroupID.x + gl_NumWorkGroups.x * row] = v;
}
if (col < p.n_cols) {
data_d[dst_idx + col] = D_TYPE(v);
}
}
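The deleted first pass combines a per-subgroup inclusive scan with a shared-memory fix-up across subgroups. A scalar C++ model of that combination (illustrative; subgroup operations are emulated serially):

    #include <vector>

    // Emulates: inclusive scan inside each "subgroup", then each element adds
    // the totals of all lower subgroups within the workgroup.
    static void block_inclusive_scan(std::vector<float> & x, size_t subgroup) {
        const size_t nsub = (x.size() + subgroup - 1) / subgroup;
        std::vector<float> partial(nsub, 0.0f);
        for (size_t s = 0; s < nsub; ++s) {
            float run = 0.0f;
            for (size_t i = s * subgroup; i < x.size() && i < (s + 1) * subgroup; ++i) {
                run += x[i];
                x[i] = run;            // subgroupInclusiveAdd
            }
            partial[s] = run;          // last lane stores the subgroup total
        }
        for (size_t s = 1; s < nsub; ++s) {
            float below = 0.0f;
            for (size_t j = 0; j < s; ++j) below += partial[j];
            for (size_t i = s * subgroup; i < x.size() && i < (s + 1) * subgroup; ++i) {
                x[i] += below;         // add partials of lower subgroups
            }
        }
    }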

View File

@ -1,66 +0,0 @@
#version 450
#include "types.glsl"
#include "sum_rows.glsl"
#extension GL_EXT_control_flow_attributes : enable
#extension GL_KHR_shader_subgroup_arithmetic : enable
#extension GL_KHR_shader_subgroup_basic : enable
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) buffer D {D_TYPE data_d[];};
layout (binding = 2) readonly buffer T {D_TYPE data_t[];};
layout (constant_id = 0) const uint BLOCK_SIZE = 128;
layout (constant_id = 1) const uint SUBGROUP_SIZE = 32;
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
shared FLOAT_TYPE temp[BLOCK_SIZE / SUBGROUP_SIZE];
void main() {
const uint row = gl_WorkGroupID.y;
const uint tid = gl_LocalInvocationID.x;
const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L);
const uint i03_offset = i03 * p.ne01*p.ne02;
const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L);
const uint i01 = row - i03_offset - i02*p.ne01;
const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03;
const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13;
const uint col = gl_GlobalInvocationID.x;
float v = 0;
// prefetch value we're adding to
if (col < p.n_cols) {
v = data_d[dst_idx + col];
}
// compute the sum of all previous blocks
uint c = tid;
float sum = 0;
while (c < gl_WorkGroupID.x) {
sum += data_t[c + gl_NumWorkGroups.x * row];
c += BLOCK_SIZE;
}
sum = subgroupAdd(sum);
if (gl_SubgroupInvocationID == 0) {
temp[gl_SubgroupID] = sum;
}
barrier();
sum = 0;
[[unroll]] for (uint s = 0; s < BLOCK_SIZE / SUBGROUP_SIZE; ++s) {
sum += temp[s];
}
// Add the sum to what the first pass computed
if (col < p.n_cols) {
data_d[dst_idx + col] = v + sum;
}
}

View File

@ -462,8 +462,7 @@ vec2 get_dm(uint ib, uint a_offset) {
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
vec2 get_dm(uint ib, uint a_offset) {
const vec2 dm = vec2(data_a_packed32[a_offset + ib].dm);
return dm;
return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m));
}
#endif
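Both variants of get_dm feed the same affine dequantization, value = d * q + m; they only differ in how the scale d and min m are loaded. A tiny C++ example of the formula (hypothetical helper and values):

    #include <cstdint>

    // Q4_1-style affine dequantization: quant q in [0,15], scale d, min m.
    static float dequant_q4_1(uint32_t q, float d, float m) {
        return d * float(q & 0xF) + m;   // e.g. d = 0.1f, m = -0.8f, q = 12 -> 0.4f
    }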

View File

@ -14,8 +14,6 @@ layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#elif defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
#define K_PER_ITER 32
#else
#error unimplemented
#endif
@ -51,15 +49,6 @@ void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#elif K_PER_ITER == 32
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 3];
cache_b_qs[4] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 4];
cache_b_qs[5] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 5];
cache_b_qs[6] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 6];
cache_b_qs[7] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + 7];
#else
#error unimplemented
#endif

View File

@ -377,118 +377,3 @@ FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif
#if defined(DATA_A_IQ1_S)
void repack8(uint ib, uint iqs, out i32vec4 out0, out i32vec4 out1) {
const uint ib32 = iqs / 32;
const uint qh = data_a[ib].qh[ib32];
const uint qs16_0 = data_a_packed16[ib].qs[(4 * ib32 + 0) / 2];
const uint qs16_1 = data_a_packed16[ib].qs[(4 * ib32 + 2) / 2];
const uint qs0 = qs16_0 & 0xFF;
const uint qs1 = qs16_0 >> 8;
const uint qs2 = qs16_1 & 0xFF;
const uint qs3 = qs16_1 >> 8;
const uint hi0 = bitfieldExtract(qh, 3 * int(0), 3);
const uint hi1 = bitfieldExtract(qh, 3 * int(1), 3);
const uint hi2 = bitfieldExtract(qh, 3 * int(2), 3);
const uint hi3 = bitfieldExtract(qh, 3 * int(3), 3);
const int32_t grid0 = int32_t(iq1s_grid_gpu[qs0 | (hi0 << 8)]);
const int32_t grid1 = int32_t(iq1s_grid_gpu[qs1 | (hi1 << 8)]);
const int32_t grid2 = int32_t(iq1s_grid_gpu[qs2 | (hi2 << 8)]);
const int32_t grid3 = int32_t(iq1s_grid_gpu[qs3 | (hi3 << 8)]);
out0 = i32vec4((grid0 >> 0) & 0x0F0F0F0F,
(grid0 >> 4) & 0x0F0F0F0F,
(grid1 >> 0) & 0x0F0F0F0F,
(grid1 >> 4) & 0x0F0F0F0F);
out1 = i32vec4((grid2 >> 0) & 0x0F0F0F0F,
(grid2 >> 4) & 0x0F0F0F0F,
(grid3 >> 0) & 0x0F0F0F0F,
(grid3 >> 4) & 0x0F0F0F0F);
}
vec2 get_dm(uint ib, uint iqs) {
const uint ib32 = iqs / 32;
const uint qh = data_a[ib].qh[ib32];
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
const float d = float(data_a[ib].d);
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
// the -1 cancels out the bias in iq1s_grid_gpu
return FLOAT_TYPE_VEC2(dl, dl * (delta - 1));
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const uint ib_k = ib_a / 8;
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
i32vec4 qs_a0;
i32vec4 qs_a1;
repack8(ib_k, iqs_k, qs_a0, qs_a1);
const vec2 dm = get_dm(ib_k, iqs_k);
q_sum += dotPacked4x8EXT(qs_a0.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a0.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a0.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a0.w, cache_b_qs[3]);
q_sum += dotPacked4x8EXT(qs_a1.x, cache_b_qs[4]);
q_sum += dotPacked4x8EXT(qs_a1.y, cache_b_qs[5]);
q_sum += dotPacked4x8EXT(qs_a1.z, cache_b_qs[6]);
q_sum += dotPacked4x8EXT(qs_a1.w, cache_b_qs[7]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm.x) * float(q_sum) + float(dm.y) * float(cache_b_ds.y));
}
#endif
#if defined(DATA_A_IQ1_M)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
const uint ib_k = ib_a / 8;
const uint iqs_k = (ib_a % 8) * 32 + iqs * 32;
const uint ib32 = iqs_k / 32;
const uint ib64 = ib32 / 2;
const uint16_t[4] scales = data_a[ib_k].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint qs32 = data_a_packed32[ib_k].qs[ib32];
const uint qh16 = data_a_packed16[ib_k].qh[ib32];
float sum = 0;
const uint sc = data_a[ib_k].scales[ib64];
[[unroll]] for (int l = 0; l < 4; ++l) {
const uint ib16 = 2 * ib32 + l / 2;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1);
const uint qh = qh16 >> (4 * l);
const uint qs = (qs32 >> (8 * l)) & 0xFF;
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const int32_t grid = int32_t(iq1s_grid_gpu[qs | ((qh & 7) << 8)]);
int32_t q_sum = 0;
q_sum += dotPacked4x8EXT((grid >> 0) & 0x0F0F0F0F, cache_b_qs[2 * l + 0]);
q_sum += dotPacked4x8EXT((grid >> 4) & 0x0F0F0F0F, cache_b_qs[2 * l + 1]);
int32_t y_sum = 0;
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 0]);
y_sum += dotPacked4x8EXT(int(0x01010101), cache_b_qs[2 * l + 1]);
// the -1 cancels out the bias in iq1s_grid_gpu
sum += dl * (q_sum + y_sum * (delta - 1));
}
sum *= float(cache_b_ds.x);
return sum;
}
#endif
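The removed kernels above lean on dotPacked4x8EXT, a dot product of four signed 8-bit lanes packed into a 32-bit integer. A C++ equivalent for reference (a scalar emulation intended to mirror that operation, not the GLSL built-in itself):

    #include <cstdint>

    // Dot product of two i8vec4 values packed into 32-bit integers.
    static int32_t dot_packed_4x8(int32_t a, int32_t b) {
        int32_t sum = 0;
        for (int i = 0; i < 4; ++i) {
            const int8_t ai = int8_t((a >> (8 * i)) & 0xFF);
            const int8_t bi = int8_t((b >> (8 * i)) & 0xFF);
            sum += int32_t(ai) * int32_t(bi);
        }
        return sum;
    }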

View File

@ -47,7 +47,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
#endif
#elif defined(DATA_A_Q4_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;
@ -63,15 +63,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q4_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint buf_idx = col * SHMEM_STRIDE + 2 * row;

const uint ib = idx / 4;
const uint iqs = idx & 0x03;

const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * dm.x + dm.y;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * dm.x + dm.y;

const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16);
const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m;
const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m;

buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy);
buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw);
@ -79,7 +80,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw);
#elif defined(DATA_A_Q5_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -96,26 +97,22 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q5_1)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint ib = idx / 4;
const uint iqs = idx & 0x03;
const vec2 dm = vec2(data_a_packed32[ib].dm);
const uint uint_qh = data_a_packed32[ib].qh;
const uvec2 qh0 = uvec2(((uint_qh >> 4*iqs) << 4) & 0x10, (uint_qh >> (4*iqs + 12)) & 0x10);
const uvec2 qh1 = uvec2(((uint_qh >> (4*iqs + 1)) << 4) & 0x10, (uint_qh >> (4*iqs + 13)) & 0x10);
const uvec2 qh2 = uvec2(((uint_qh >> (4*iqs + 2)) << 4) & 0x10, (uint_qh >> (4*iqs + 14)) & 0x10);
const uvec2 qh3 = uvec2(((uint_qh >> (4*iqs + 3)) << 4) & 0x10, (uint_qh >> (4*iqs + 15)) & 0x10);
const uint vui = data_a_packed32[ib].qs[iqs];
const vec4 v0 = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, ((vui >> 12) & 0xF) | qh1.y) * dm.x + dm.y;
const vec4 v1 = vec4(((vui >> 16) & 0xF) | qh2.x, ((vui >> 20) & 0xF) | qh2.y, ((vui >> 24) & 0xF) | qh3.x, ((vui >> 28) & 0xF) | qh3.y) * dm.x + dm.y;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xz);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v1.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v0.yw);
buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.yw);

const uint buf_idx = col * SHMEM_STRIDE + row;
const uint ib = idx / 8;
const uint iqs = idx & 0x07;
const float d = float(data_a_packed16[ib].d);
const float m = float(data_a_packed16[ib].m);
const uint uint_qh = data_a_packed16[ib].qh;
const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10);
const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10);
const uint vui = uint(data_a_packed16[ib].qs[iqs]);
const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m;
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz);
buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw);
#elif defined(DATA_A_Q8_0)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -134,21 +131,20 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint qsi = (iqs / 64) * 16 + (iqs % 16); // 0..15
const uint scalesi = iqs / 8; // 0..15
const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6

const vec4 qs = vec4(unpack8((data_a_packed32[ib].qs[qsi / 2] >> qsshift) & 0x03030303));
const uint scales = data_a[ib].scales[scalesi];
const vec2 dm = vec2(data_a[ib].dm);
const vec4 v = dm.x * float(scales & 0xF) * qs - dm.y * float(scales >> 4);
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw);

const uvec2 qs = uvec2(unpack8(data_a_packed16[ib].qs[qsi]));
const vec2 v = dm.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - dm.y * float(scales >> 4);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_Q3_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -177,8 +173,8 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -204,16 +200,16 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;

const vec4 q = vec4(unpack8((data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));

const vec2 q = vec2(unpack8((uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
#elif defined(DATA_A_Q5_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;

const uint ib = idx / 64; // 4 values per idx
const uint iqs = (idx % 64) * 2; // 0,2,4..126
const uint ib = idx / 128; // 2 values per idx
const uint iqs = idx % 128; // 0..127

const uint n = iqs / 32; // 0,1,2,3
const uint b = (iqs % 32) / 16; // 0,1
@ -240,12 +236,12 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
const float d = loadd.x * sc;
const float m = -loadd.y * mbyte;

const uint qs = (data_a_packed32[ib].qs[qsi / 4] >> (b * 4)) & 0x0F0F0F0F;
const uint qh = ((data_a_packed32[ib].qh[qhi / 4] >> (iqs / 16)) & 0x01010101) << 4;
const vec4 q = vec4(unpack8(qs | qh));
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(fma(d, q.x, m), fma(d, q.y, m));
buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(fma(d, q.z, m), fma(d, q.w, m));

const uint qs = (uint(data_a_packed16[ib].qs[qsi / 2]) >> (b * 4)) & 0x0F0F;
const uint qh = ((uint(data_a_packed16[ib].qh[qhi / 2]) >> (iqs / 16)) & 0x0101) << 4;
const vec2 q = vec2(unpack8(qs | qh).xy);
buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, q.x, m),
fma(d, q.y, m));
#elif defined(DATA_A_Q6_K)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;
const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2;
@ -459,7 +455,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy);
#elif defined(DATA_A_IQ4_NL)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = idx & 0x07;
@ -473,7 +469,7 @@ void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uin
kvalues_iq4nl[vui >> 12]);
#elif defined(DATA_A_MXFP4)
const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row;

const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 4;
const uint buf_idx = col * SHMEM_STRIDE + row;

const uint ib = idx / 8;
const uint iqs = (idx & 0x07) * 2;

View File

@ -15,7 +15,6 @@
layout (push_constant) uniform parameter
{
uint ne;
uint num_blocks;
} p;

#include "types.glsl"
@ -34,7 +33,8 @@ layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
shared float shmem[GROUP_SIZE];
#endif

void quantize(const uint wgid) {
void quantize() {
const uint wgid = gl_WorkGroupID.x;
const uint tid = INVOCATION_ID;

// Each thread handles a vec4, so 8 threads handle a block
@ -45,7 +45,11 @@ void quantize(const uint wgid) {
const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8;

#ifdef QBLOCK_X4
#ifndef QBLOCK_X4
if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return;
}
#else
const uint ibx4_outer = ib / 4;
const uint ibx4_inner = ib % 4;
@ -119,9 +123,5 @@ void quantize(const uint wgid) {
}

void main() {
uint wgid = gl_WorkGroupID.x;
quantize();
while (wgid < p.num_blocks) {
quantize(wgid);
wgid += gl_NumWorkGroups.x;
}
}
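The old side of this hunk uses a persistent-workgroup (grid-stride) loop so a fixed dispatch can cover p.num_blocks work items, while the new side launches one workgroup per block. The same pattern expressed in C++ terms (illustrative sketch; on the GPU the outer loop is the parallel dispatch):

    // Grid-stride loop: 'num_groups' workers cover 'num_blocks' items,
    // each worker striding by the total worker count.
    static void run_persistent(unsigned num_groups, unsigned num_blocks,
                               void (*quantize_block)(unsigned)) {
        for (unsigned g = 0; g < num_groups; ++g) {
            for (unsigned wgid = g; wgid < num_blocks; wgid += num_groups) {
                quantize_block(wgid);
            }
        }
    }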

View File

@ -7,10 +7,6 @@
#include "types.glsl" #include "types.glsl"
#define GATING_FUNC_SOFTMAX 0
#define GATING_FUNC_SIGMOID 1
#define GATING_FUNC_SOFTMAX_WEIGHT 2
layout (push_constant) uniform parameter
{
uint n_rows;
@ -18,18 +14,15 @@ layout (push_constant) uniform parameter
uint n_expert_used;
float clamp_min;
float clamp_max;
uint gating_func;
uint has_bias;
uint with_norm;
float output_scale;
float output_bias;
};

layout(local_size_x_id = 0, local_size_y = 4, local_size_z = 1) in;
layout(constant_id = 0) const uint WARP_SIZE = 32;
layout(constant_id = 1) const uint n_experts_spec = 512;
layout(constant_id = 2) const bool nexperts_use_push = false;
layout(constant_id = 2) const bool with_norm = true;
layout(constant_id = 3) const bool late_softmax = false;
layout(constant_id = 4) const bool nexperts_use_push = false;
uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
@ -38,9 +31,8 @@ uint n_experts = nexperts_use_push ? n_experts_push : n_experts_spec;
const uint experts_per_thread = CEIL_DIV(n_experts_spec, WARP_SIZE);

layout (binding = 0, std430) readonly buffer Logits {float logits[];};
layout (binding = 1, std430) readonly buffer BiasProbs {float bias[];};
layout (binding = 2, std430) writeonly buffer Weights {float weights[];};
layout (binding = 3, std430) writeonly buffer Ids {uint ids[];};
layout (binding = 1, std430) writeonly buffer Weights {float weights[];};
layout (binding = 2, std430) writeonly buffer Ids {uint ids[];};

const float INFINITY = 1.0 / 0.0;
@ -95,45 +87,20 @@ void main() {
}

const uint logits_offset = n_experts * row;
const uint bias_offset = 0; // 1D
const uint weights_offset = n_expert_used * row;
const uint ids_offset = n_experts * row;

const uint lane = gl_SubgroupInvocationID;

float probs[experts_per_thread];
float wt[experts_per_thread];
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
probs[i] = -INFINITY;
}
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[logits_offset + expert] : -INFINITY;
}

if (gating_func == GATING_FUNC_SOFTMAX) {
softmax_warp_inplace(probs, n_experts, lane, nexperts_use_push);
if (!late_softmax) {
softmax_warp_inplace(wt, n_experts, lane, nexperts_use_push);
} else if (gating_func == GATING_FUNC_SIGMOID) {
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? 1.f / (1.f + exp(-probs[i / WARP_SIZE])) : -INFINITY;
}
}
float selection_probs[experts_per_thread];
if (has_bias != 0) {
[[unroll]]
for (uint i = 0; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
selection_probs[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? probs[i / WARP_SIZE] + bias[bias_offset + expert] : -INFINITY;
}
} else {
[[unroll]]
for (int i = 0; i < experts_per_thread; i++) {
selection_probs[i] = probs[i];
}
} }
// at this point, each thread holds a portion of softmax, // at this point, each thread holds a portion of softmax,
@ -150,16 +117,14 @@ void main() {
} }
for (int k = 0; k < n_expert_used; k++) {
float max_val = probs[0];
float max_val = wt[0];
float max_val_s = selection_probs[0];
uint max_expert = lane;

[[unroll]]
for (uint i = WARP_SIZE; i < n_experts; i += WARP_SIZE) {
const uint expert = i + lane;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_probs[i / WARP_SIZE] > max_val_s) {
max_val = probs[i / WARP_SIZE];
for (int i = 1; i < experts_per_thread; i++) {
const uint expert = lane + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_val_s = selection_probs[i / WARP_SIZE];
max_expert = expert;
}
}
@ -167,11 +132,9 @@ void main() {
[[unroll]]
for (uint mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = subgroupShuffleXor(max_val, mask);
const float val_s = subgroupShuffleXor(max_val_s, mask);
const uint expert = subgroupShuffleXor(max_expert, mask);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
@ -181,14 +144,16 @@ void main() {
}

if ((max_expert & (WARP_SIZE - 1)) == lane) {
selection_probs[max_expert / WARP_SIZE] = -INFINITY;
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[ids_offset + k] = max_expert;
wt_sum += max_val;
if (with_norm) {
wt_sum += max_val;
}
}
}

if (with_norm != 0) {
if (with_norm) {
wt_sum = subgroupAdd(wt_sum);
wt_sum = clamp(wt_sum, clamp_min, clamp_max);
const float inv_sum = 1.0f / wt_sum;
@ -199,7 +164,7 @@ void main() {
}
}

if (gating_func == GATING_FUNC_SOFTMAX_WEIGHT) {
if (late_softmax) {
softmax_warp_inplace(output_weights, n_expert_used, lane, true);
}
@ -207,7 +172,7 @@ void main() {
for (uint i = 0; i < experts_per_thread; ++i) {
uint idx = i * WARP_SIZE + lane;
if (idx < n_expert_used) {
weights[weights_offset + idx] = output_scale * output_weights[i] + output_bias;
weights[weights_offset + idx] = output_weights[i];
}
}
}
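The selection loop in this shader is an iterative arg-max: k times, find the current maximum, record its index, then mask it out with -inf. A scalar C++ sketch of that top-k scheme (names are illustrative, not from the shader):

    #include <limits>
    #include <vector>

    // Repeatedly pick the largest remaining score; earlier indices win ties.
    static std::vector<size_t> topk_ids(std::vector<float> scores, size_t k) {
        std::vector<size_t> ids;
        for (size_t n = 0; n < k; ++n) {
            size_t best = 0;
            for (size_t i = 1; i < scores.size(); ++i) {
                if (scores[i] > scores[best]) best = i;
            }
            ids.push_back(best);
            scores[best] = -std::numeric_limits<float>::infinity();  // mask out
        }
        return ids;
    }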

View File

@ -396,12 +396,6 @@ struct block_iq1_s {
uint16_t qh[QUANT_K_IQ1_S/32];
};
struct block_iq1_s_packed16 {
float16_t d;
uint16_t qs[QUANT_K_IQ1_S/8/2];
uint16_t qh[QUANT_K_IQ1_S/32];
};
#define QUANT_K_IQ1_M 256
#define QUANT_R_IQ1_M 1
@ -411,18 +405,6 @@ struct block_iq1_m {
uint16_t scales[QUANT_K_IQ1_M/64];
};
struct block_iq1_m_packed16 {
uint16_t qs[QUANT_K_IQ1_M/8/2];
uint16_t qh[QUANT_K_IQ1_M/16/2];
uint16_t scales[QUANT_K_IQ1_M/64];
};
struct block_iq1_m_packed32 {
uint32_t qs[QUANT_K_IQ1_M/8/4];
uint32_t qh[QUANT_K_IQ1_M/16/4];
uint32_t scales[QUANT_K_IQ1_M/64/2];
};
struct block_iq1_m_packed64 {
uint64_t qs[QUANT_K_IQ1_M/8/8];
uint64_t qh[QUANT_K_IQ1_M/16/8];
@ -433,15 +415,12 @@ struct block_iq1_m_packed64 {
#define QUANT_K QUANT_K_IQ1_S
#define QUANT_R QUANT_R_IQ1_S
#define A_TYPE block_iq1_s
#define A_TYPE_PACKED16 block_iq1_s_packed16
#endif

#if defined(DATA_A_IQ1_M)
#define QUANT_K QUANT_K_IQ1_M
#define QUANT_R QUANT_R_IQ1_M
#define A_TYPE block_iq1_m
#define A_TYPE_PACKED16 block_iq1_m_packed16
#define A_TYPE_PACKED32 block_iq1_m_packed32
#endif

#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M)
@ -580,270 +559,7 @@ const uint[1024] iq1s_grid_const = {
0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557
};
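Per the comment on the removed table below, each grid value is stored with a +1 bias in a 4-bit lane, so masks like 0x0F0F0F0F extract whole lanes and the bias is cancelled by the "-1" noted in the removed mmvq code above. A minimal C++ illustration of the decoding (hypothetical helper, not from the codebase):

    #include <cstdint>

    // Each packed nibble holds (grid value + 1); undo the bias after masking.
    static int8_t grid_value(uint32_t gpu_word, int lane) {
        const uint32_t nib = (gpu_word >> (4 * lane)) & 0xF;
        return int8_t(nib) - 1;
    }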
// Same content as iq1s_grid_const except each 2-bit value is expanded to 4-bit
// and has 1 added to it (allows packed values to be extracted with & 0x0F0F0F0F
// and 0xF0F0F0F0).
const uint32_t[2048] iq1s_grid_gpu_const = {
0x00000000, 0x00000002, 0x00000101, 0x00000200, 0x00000202, 0x00010001, 0x00010101, 0x00020000,
0x00020002, 0x00020200, 0x00020202, 0x01000101, 0x01010001, 0x01010100, 0x01010102, 0x01020101,
0x02000000, 0x02000002, 0x02000200, 0x02000202, 0x02010101, 0x02020000, 0x02020002, 0x02020200,
0x02020202, 0x00000110, 0x00000111, 0x00010011, 0x00010110, 0x00010112, 0x00010211, 0x00010212,
0x00020111, 0x01000011, 0x01000112, 0x01000211, 0x01010012, 0x01010111, 0x01010212, 0x01020011,
0x01020110, 0x01020112, 0x01020210, 0x02000111, 0x02010011, 0x02010110, 0x02010112, 0x02020111,
0x00000020, 0x00000022, 0x00000220, 0x00000222, 0x00010121, 0x00020020, 0x00020022, 0x00020220,
0x00020222, 0x01000121, 0x01010021, 0x01010221, 0x01020120, 0x01020221, 0x02000020, 0x02000022,
0x02000220, 0x02000222, 0x02010021, 0x02010121, 0x02010221, 0x02020020, 0x02020022, 0x02020220,
0x02020222, 0x00011001, 0x00011100, 0x00011102, 0x00021101, 0x01001001, 0x01001201, 0x01011101,
0x01011202, 0x01021100, 0x01021101, 0x02011001, 0x02011201, 0x02021101, 0x00001011, 0x00001110,
0x00001111, 0x00001112, 0x00011111, 0x00011210, 0x00011212, 0x00021211, 0x01001010, 0x01001111,
0x01001212, 0x01011010, 0x01011011, 0x01011110, 0x01011111, 0x01011112, 0x01011211, 0x01021010,
0x01021012, 0x01021111, 0x01021210, 0x01021212, 0x02001011, 0x02011011, 0x02011111, 0x02011210,
0x02011212, 0x02021011, 0x02021110, 0x02021111, 0x02021112, 0x02021211, 0x00011120, 0x00011221,
0x01001021, 0x01001120, 0x01011020, 0x01011022, 0x01011121, 0x01011220, 0x01021020, 0x01021021,
0x01021122, 0x01021221, 0x02001121, 0x02011021, 0x02011120, 0x02011221, 0x00002000, 0x00002002,
0x00002200, 0x00002202, 0x00012101, 0x00022000, 0x00022002, 0x00022200, 0x00022202, 0x01002101,
0x01012001, 0x01012102, 0x01022101, 0x02002000, 0x02002002, 0x02002200, 0x02002202, 0x02012101,
0x02022000, 0x02022002, 0x02022200, 0x02022202, 0x00002111, 0x00012011, 0x00012110, 0x00012211,
0x00022110, 0x00022111, 0x01002011, 0x01012010, 0x01012011, 0x01012111, 0x01022011, 0x01022110,
0x01022211, 0x02012011, 0x02012110, 0x02012112, 0x02012211, 0x02022111, 0x00002020, 0x00002022,
0x00002220, 0x00002222, 0x00012121, 0x00022020, 0x00022022, 0x00022220, 0x00022222, 0x01002121,
0x01012021, 0x01012221, 0x01022021, 0x01022121, 0x02002020, 0x02002022, 0x02002121, 0x02002220,
0x02002222, 0x02012121, 0x02022020, 0x02022022, 0x02022220, 0x02022222, 0x00110000, 0x00110001,
0x00110100, 0x00110201, 0x00120100, 0x00120101, 0x01100001, 0x01100100, 0x01110000, 0x01110101,
0x01110200, 0x01120001, 0x01120100, 0x01120101, 0x01120201, 0x02110001, 0x02110100, 0x02110102,
0x02120001, 0x02120101, 0x00100011, 0x00100110, 0x00100112, 0x00100211, 0x00110010, 0x00110012,
0x00110111, 0x00110210, 0x00120011, 0x00120110, 0x00120211, 0x01100111, 0x01100212, 0x01110010,
0x01110011, 0x01110012, 0x01110110, 0x01110111, 0x01110112, 0x01110211, 0x01120010, 0x01120111,
0x02100110, 0x02110012, 0x02110111, 0x02120011, 0x02120110, 0x00110021, 0x00110120, 0x00110122,
0x00120121, 0x01100020, 0x01100122, 0x01100221, 0x01110022, 0x01110121, 0x01110220, 0x01110222,
0x01120120, 0x01120122, 0x02100121, 0x02110021, 0x02110120, 0x02110122, 0x02120121, 0x00101001,
0x00101102, 0x00101201, 0x00111100, 0x00111101, 0x00111200, 0x00111201, 0x00121001, 0x00121102,
0x01101001, 0x01101101, 0x01101102, 0x01101200, 0x01101202, 0x01111001, 0x01111100, 0x01111101,
0x01111102, 0x01111201, 0x01121002, 0x01121101, 0x01121200, 0x02101100, 0x02101201, 0x02111000,
0x02111100, 0x02111101, 0x02111200, 0x02111201, 0x02111202, 0x02121001, 0x02121100, 0x02121101,
0x02121201, 0x00101012, 0x00101111, 0x00101212, 0x00111011, 0x00111110, 0x00111111, 0x00111112,
0x00111211, 0x00121010, 0x00121012, 0x00121111, 0x00121210, 0x00121212, 0x01101011, 0x01101110,
0x01101111, 0x01101112, 0x01111011, 0x01111012, 0x01111110, 0x01111111, 0x01111112, 0x01111211,
0x01111212, 0x01121011, 0x01121110, 0x01121111, 0x01121112, 0x01121211, 0x02101010, 0x02101012,
0x02101110, 0x02101111, 0x02101210, 0x02101212, 0x02111010, 0x02111011, 0x02111110, 0x02111111,
0x02111112, 0x02111211, 0x02111212, 0x02121010, 0x02121012, 0x02121111, 0x00101021, 0x00101120,
0x00101121, 0x00101122, 0x00111121, 0x00111122, 0x00111220, 0x00111222, 0x00121021, 0x00121122,
0x01101020, 0x01101022, 0x01101120, 0x01101121, 0x01101220, 0x01101222, 0x01111021, 0x01111121,
0x01111122, 0x01111220, 0x01111221, 0x01121021, 0x01121120, 0x01121121, 0x01121220, 0x01121221,
0x01121222, 0x02101122, 0x02101222, 0x02111022, 0x02111121, 0x02121120, 0x02121221, 0x00112001,
0x00112102, 0x00122101, 0x01102001, 0x01102100, 0x01102102, 0x01102201, 0x01112000, 0x01112101,
0x01112200, 0x01112202, 0x01122000, 0x01122001, 0x01122100, 0x01122102, 0x01122201, 0x02102101,
0x02112001, 0x02112100, 0x02122101, 0x00112010, 0x00112012, 0x00112111, 0x00112212, 0x00122011,
0x00122111, 0x01102012, 0x01102110, 0x01102111, 0x01102210, 0x01112011, 0x01112110, 0x01112111,
0x01112112, 0x01112211, 0x01112212, 0x01122010, 0x01122111, 0x01122212, 0x02102211, 0x02112011,
0x02112012, 0x02112111, 0x02112210, 0x02122011, 0x02122112, 0x02122211, 0x00102221, 0x00112122,
0x00122120, 0x00122122, 0x01102120, 0x01102122, 0x01102221, 0x01112020, 0x01112022, 0x01112121,
0x01112220, 0x01122021, 0x01122122, 0x01122221, 0x02102121, 0x02112021, 0x02112122, 0x02112222,
0x00200000, 0x00200002, 0x00200200, 0x00200202, 0x00210101, 0x00220000, 0x00220002, 0x00220101,
0x00220200, 0x00220202, 0x01200101, 0x01210001, 0x01210201, 0x01220001, 0x01220101, 0x02200000,
0x02200002, 0x02200200, 0x02200202, 0x02210101, 0x02220000, 0x02220002, 0x02220101, 0x02220200,
0x02220202, 0x00200111, 0x00210011, 0x00210110, 0x00210211, 0x00220111, 0x01200012, 0x01200110,
0x01200211, 0x01210111, 0x01210210, 0x01210212, 0x01220011, 0x01220110, 0x01220111, 0x01220112,
0x02200111, 0x02210010, 0x02210112, 0x02210211, 0x02220111, 0x00200021, 0x00200220, 0x00200222,
0x00210021, 0x00210121, 0x00220020, 0x00220022, 0x00220220, 0x00220222, 0x01200121, 0x01210021,
0x01210122, 0x01210221, 0x01220121, 0x02200021, 0x02200220, 0x02200222, 0x02210021, 0x02210121,
0x02220020, 0x02220022, 0x02220220, 0x02220222, 0x00201101, 0x00211100, 0x00211102, 0x00211201,
0x00221101, 0x01201100, 0x01201101, 0x01201102, 0x01201201, 0x01211002, 0x01211101, 0x01211200,
0x01211202, 0x01221102, 0x02201101, 0x02211001, 0x02211100, 0x02211201, 0x02221001, 0x02221101,
0x00201211, 0x00211111, 0x00221011, 0x00221211, 0x01201010, 0x01201111, 0x01201210, 0x01211011,
0x01211110, 0x01211111, 0x01211211, 0x01221012, 0x01221111, 0x01221210, 0x02201211, 0x02211010,
0x02211110, 0x02211111, 0x02211210, 0x02211212, 0x02221011, 0x02221110, 0x02221112, 0x02221211,
0x00201121, 0x00211020, 0x00211022, 0x00211221, 0x00221121, 0x01201021, 0x01201221, 0x01211121,
0x01221020, 0x01221021, 0x01221221, 0x02201120, 0x02201122, 0x02211020, 0x02211222, 0x00202000,
0x00202002, 0x00202200, 0x00202202, 0x00212101, 0x00222000, 0x00222002, 0x00222200, 0x00222202,
0x01202101, 0x01212001, 0x01212100, 0x01222101, 0x02202000, 0x02202002, 0x02202200, 0x02202202,
0x02222000, 0x02222002, 0x02222200, 0x02222202, 0x00202211, 0x00212011, 0x00212110, 0x00212211,
0x00222111, 0x01202112, 0x01202211, 0x01212012, 0x01212111, 0x01222011, 0x01222110, 0x01222112,
0x01222211, 0x02202111, 0x02212010, 0x02212112, 0x02212211, 0x02222110, 0x02222111, 0x00202020,
0x00202022, 0x00202220, 0x00202222, 0x00222020, 0x00222022, 0x00222220, 0x00222222, 0x01202121,
0x01212021, 0x01212122, 0x01212221, 0x01222121, 0x02202020, 0x02202022, 0x02202220, 0x02202222,
0x02212121, 0x02222020, 0x02222022, 0x02222220, 0x02222222, 0x10000101, 0x10010001, 0x10010102,
0x10020101, 0x11000201, 0x11010002, 0x11010101, 0x11010200, 0x11010202, 0x11020001, 0x11020100,
0x11020102, 0x12010100, 0x12010201, 0x12020001, 0x12020102, 0x10000010, 0x10000011, 0x10000110,
0x10000112, 0x10000211, 0x10010012, 0x10010111, 0x10010112, 0x10010210, 0x10010212, 0x10020011,
0x10020112, 0x10020211, 0x11000111, 0x11000210, 0x11000212, 0x11010011, 0x11010110, 0x11010111,
0x11010112, 0x11010211, 0x11010212, 0x11020111, 0x11020210, 0x11020212, 0x12000011, 0x12000110,
0x12000112, 0x12010010, 0x12010012, 0x12010111, 0x12020010, 0x12020011, 0x12020012, 0x10000121,
0x10010021, 0x10010120, 0x10010122, 0x10020121, 0x11000021, 0x11010022, 0x11010121, 0x11010222,
0x11020120, 0x11020221, 0x12000221, 0x12010120, 0x12020121, 0x10001001, 0x10011101, 0x10011201,
0x10021201, 0x11001101, 0x11001200, 0x11001202, 0x11011001, 0x11011100, 0x11011101, 0x11011102,
0x11021001, 0x11021002, 0x11021101, 0x11021200, 0x11021202, 0x12001001, 0x12001102, 0x12001201,
0x12011000, 0x12011002, 0x12011101, 0x12021000, 0x12021001, 0x12021201, 0x10001011, 0x10001012,
0x10001111, 0x10001212, 0x10011011, 0x10011110, 0x10011111, 0x10011112, 0x10011211, 0x10021010,
0x10021111, 0x10021212, 0x11001011, 0x11001110, 0x11001111, 0x11001112, 0x11001211, 0x11011010,
0x11011011, 0x11011110, 0x11011111, 0x11011112, 0x11011210, 0x11011211, 0x11021011, 0x11021110,
0x11021111, 0x11021112, 0x11021211, 0x12001012, 0x12001110, 0x12001111, 0x12001210, 0x12011011,
0x12011110, 0x12011111, 0x12011112, 0x12011211, 0x12011212, 0x12021111, 0x12021210, 0x12021212,
0x10001021, 0x10001121, 0x10001221, 0x10011120, 0x10011121, 0x10011220, 0x10011222, 0x10021021,
0x10021120, 0x10021221, 0x11001020, 0x11001022, 0x11001121, 0x11001220, 0x11011020, 0x11011021,
0x11011022, 0x11011121, 0x11011122, 0x11011221, 0x11021022, 0x11021121, 0x11021220, 0x12001021,
0x12001121, 0x12001222, 0x12011120, 0x12011121, 0x12021021, 0x12021120, 0x12021122, 0x10002101,
0x10012001, 0x10012101, 0x10012202, 0x10022101, 0x11002002, 0x11002201, 0x11012000, 0x11012101,
0x11012200, 0x11022001, 0x11022100, 0x11022102, 0x11022201, 0x12002101, 0x12012001, 0x12012100,
0x12012102, 0x12012201, 0x12022101, 0x10002011, 0x10002111, 0x10002112, 0x10002212, 0x10012010,
0x10012110, 0x10012111, 0x10012210, 0x10022011, 0x10022110, 0x10022112, 0x11002010, 0x11002111,
0x11002212, 0x11012011, 0x11012012, 0x11012110, 0x11012111, 0x11012112, 0x11012211, 0x11022010,
0x11022012, 0x11022111, 0x11022112, 0x11022212, 0x12002112, 0x12002211, 0x12012012, 0x12012111,
0x12012112, 0x12012210, 0x12022011, 0x12022110, 0x12022112, 0x12022211, 0x10012122, 0x11002120,
0x11002122, 0x11002221, 0x11012121, 0x11012220, 0x11012222, 0x11022120, 0x11022221, 0x12012120,
0x12022121, 0x10100001, 0x10100100, 0x10100101, 0x10100102, 0x10100201, 0x10110002, 0x10110101,
0x10110202, 0x10120001, 0x10120100, 0x10120201, 0x11100000, 0x11100101, 0x11100200, 0x11110001,
0x11110100, 0x11110101, 0x11110102, 0x11110201, 0x11120101, 0x11120200, 0x12100102, 0x12100201,
0x12110101, 0x12110200, 0x12120000, 0x12120001, 0x12120102, 0x12120201, 0x10100111, 0x10100210,
0x10100211, 0x10100212, 0x10110011, 0x10110110, 0x10110111, 0x10110112, 0x10110210, 0x10110211,
0x10120010, 0x10120111, 0x10120112, 0x10120210, 0x10120212, 0x11100011, 0x11100110, 0x11100111,
0x11100112, 0x11100211, 0x11110010, 0x11110011, 0x11110012, 0x11110110, 0x11110111, 0x11110112,
0x11110210, 0x11110211, 0x11110212, 0x11120011, 0x11120110, 0x11120111, 0x11120112, 0x11120211,
0x12100012, 0x12100111, 0x12110011, 0x12110110, 0x12110111, 0x12110112, 0x12110211, 0x12120010,
0x12120111, 0x12120212, 0x10100021, 0x10100122, 0x10110022, 0x10110121, 0x10110222, 0x10120021,
0x10120120, 0x11100022, 0x11100121, 0x11100222, 0x11110021, 0x11110120, 0x11110121, 0x11110122,
0x11110221, 0x11120022, 0x11120121, 0x12100121, 0x12110020, 0x12110022, 0x12110121, 0x12110221,
0x12110222, 0x12120120, 0x10101100, 0x10101101, 0x10111001, 0x10111100, 0x10111101, 0x10111102,
0x10111200, 0x10111201, 0x10121001, 0x10121101, 0x10121200, 0x10121202, 0x11101001, 0x11101100,
0x11101101, 0x11101102, 0x11101201, 0x11101202, 0x11111000, 0x11111001, 0x11111100, 0x11111101,
0x11111102, 0x11111200, 0x11111201, 0x11111202, 0x11121001, 0x11121002, 0x11121100, 0x11121101,
0x11121102, 0x11121201, 0x12101000, 0x12101200, 0x12101202, 0x12111001, 0x12111100, 0x12111101,
0x12111102, 0x12111201, 0x12121001, 0x12121100, 0x12121101, 0x12121202, 0x10101011, 0x10101012,
0x10101110, 0x10101111, 0x10101112, 0x10101211, 0x10111010, 0x10111011, 0x10111012, 0x10111110,
0x10111111, 0x10111112, 0x10111211, 0x10111212, 0x10121011, 0x10121110, 0x10121111, 0x10121112,
0x10121211, 0x11101010, 0x11101011, 0x11101012, 0x11101110, 0x11101111, 0x11101112, 0x11101210,
0x11101211, 0x11111010, 0x11111011, 0x11111012, 0x11111110, 0x11111111, 0x11111112, 0x11111210,
0x11111211, 0x11111212, 0x11121010, 0x11121011, 0x11121110, 0x11121111, 0x11121112, 0x11121210,
0x11121211, 0x11121212, 0x12101011, 0x12101110, 0x12101111, 0x12101211, 0x12101212, 0x12111010,
0x12111011, 0x12111110, 0x12111111, 0x12111112, 0x12111210, 0x12111211, 0x12121011, 0x12121110,
0x12121111, 0x12121112, 0x12121211, 0x10101020, 0x10101021, 0x10101022, 0x10101120, 0x10101122,
0x10101220, 0x10101221, 0x10111021, 0x10111120, 0x10111121, 0x10111220, 0x10111221, 0x10121020,
0x10121021, 0x10121022, 0x10121120, 0x10121121, 0x10121122, 0x10121220, 0x10121221, 0x11101021,
0x11101121, 0x11101122, 0x11101220, 0x11101221, 0x11101222, 0x11111020, 0x11111021, 0x11111022,
0x11111120, 0x11111121, 0x11111122, 0x11111220, 0x11111221, 0x11111222, 0x11121021, 0x11121120,
0x11121121, 0x11121221, 0x12101022, 0x12101121, 0x12101122, 0x12101220, 0x12101221, 0x12101222,
0x12111021, 0x12111121, 0x12111222, 0x12121022, 0x12121121, 0x12121122, 0x12121220, 0x12121221,
0x10102100, 0x10102101, 0x10102102, 0x10102201, 0x10112000, 0x10112101, 0x10112200, 0x10122001,
0x10122202, 0x11102101, 0x11102200, 0x11102202, 0x11112001, 0x11112100, 0x11112101, 0x11112102,
0x11112200, 0x11112201, 0x11122000, 0x11122002, 0x11122100, 0x11122101, 0x12102002, 0x12102201,
0x12112000, 0x12112002, 0x12112101, 0x12112200, 0x12122001, 0x12122201, 0x10102011, 0x10102012,
0x10102111, 0x10102212, 0x10112011, 0x10112110, 0x10112111, 0x10112112, 0x10112211, 0x10122111,
0x11102011, 0x11102110, 0x11102111, 0x11102112, 0x11102211, 0x11112010, 0x11112011, 0x11112012,
0x11112110, 0x11112111, 0x11112112, 0x11112210, 0x11112211, 0x11112212, 0x11122011, 0x11122110,
0x11122111, 0x11122112, 0x11122211, 0x12102011, 0x12102111, 0x12102211, 0x12112011, 0x12112110,
0x12112111, 0x12112112, 0x12112210, 0x12112211, 0x12122111, 0x10102120, 0x10102220, 0x10112121,
0x10112222, 0x10122020, 0x10122121, 0x10122122, 0x10122221, 0x11102121, 0x11102220, 0x11102221,
0x11112021, 0x11112121, 0x11112122, 0x11112220, 0x11112221, 0x11122022, 0x11122121, 0x11122220,
0x11122222, 0x12102021, 0x12102222, 0x12112022, 0x12112121, 0x12112122, 0x12112220, 0x12112222,
0x12122021, 0x10200101, 0x10210100, 0x10210102, 0x10210201, 0x10220101, 0x11200100, 0x11210000,
0x11210101, 0x11210102, 0x11210200, 0x11210202, 0x11220001, 0x11220100, 0x11220102, 0x11220201,
0x12200001, 0x12210102, 0x12220101, 0x10200011, 0x10200110, 0x10200112, 0x10200211, 0x10210012,
0x10210111, 0x10220011, 0x10220012, 0x10220112, 0x10220211, 0x11200111, 0x11200211, 0x11210011,
0x11210111, 0x11210112, 0x11210211, 0x11220111, 0x11220112, 0x11220212, 0x12200110, 0x12200212,
0x12210012, 0x12210111, 0x12220011, 0x12220112, 0x12220211, 0x10210021, 0x10210122, 0x10210221,
0x11200020, 0x11200021, 0x11200122, 0x11210121, 0x11210122, 0x11210220, 0x11220020, 0x12200121,
0x12210021, 0x12210122, 0x12220121, 0x10211001, 0x10211002, 0x10211101, 0x10211102, 0x10211202,
0x10221001, 0x10221102, 0x10221201, 0x11201000, 0x11201002, 0x11201101, 0x11201200, 0x11201202,
0x11211001, 0x11211100, 0x11211101, 0x11211102, 0x11211201, 0x11211202, 0x11221000, 0x11221002,
0x11221101, 0x12201100, 0x12201101, 0x12201201, 0x12211000, 0x12211002, 0x12211100, 0x12211101,
0x12211102, 0x12211200, 0x12211202, 0x12221001, 0x12221100, 0x12221201, 0x10201111, 0x10201210,
0x10201212, 0x10211011, 0x10211111, 0x10211112, 0x10211211, 0x11201110, 0x11201111, 0x11201112,
0x11201211, 0x11211010, 0x11211011, 0x11211110, 0x11211111, 0x11211112, 0x11211211, 0x11221011,
0x11221110, 0x11221111, 0x11221112, 0x11221211, 0x12201112, 0x12201211, 0x12201212, 0x12211011,
0x12211111, 0x12211112, 0x12211211, 0x12211212, 0x12221012, 0x12221111, 0x12221112, 0x12221210,
0x10201022, 0x10201221, 0x10211121, 0x10221020, 0x10221122, 0x10221220, 0x10221221, 0x11201020,
0x11201121, 0x11201220, 0x11201222, 0x11211021, 0x11211120, 0x11211121, 0x11211122, 0x11211220,
0x11211222, 0x11221020, 0x11221121, 0x11221220, 0x12201020, 0x12201022, 0x12201121, 0x12201222,
0x12211120, 0x12211122, 0x12211220, 0x12211221, 0x12221020, 0x12221120, 0x12221122, 0x12221222,
0x10212102, 0x10212201, 0x10222101, 0x11202001, 0x11212002, 0x11212101, 0x11212202, 0x11222001,
0x11222201, 0x12202101, 0x12212001, 0x12212200, 0x12222102, 0x10202011, 0x10202110, 0x10212010,
0x10212111, 0x10222011, 0x10222110, 0x10222112, 0x10222211, 0x11202010, 0x11202011, 0x11202111,
0x11202112, 0x11202210, 0x11212011, 0x11212110, 0x11212111, 0x11212112, 0x11212211, 0x11222010,
0x11222111, 0x11222212, 0x12202012, 0x12202110, 0x12202212, 0x12212111, 0x12222011, 0x12222110,
0x12222111, 0x12222211, 0x10212021, 0x10212122, 0x10212220, 0x11202021, 0x11202120, 0x11202221,
0x11212020, 0x11212121, 0x11212220, 0x11212222, 0x11222120, 0x11222121, 0x11222221, 0x12202122,
0x12212120, 0x12212220, 0x12212222, 0x12222122, 0x20000000, 0x20000002, 0x20000200, 0x20000202,
0x20020000, 0x20020002, 0x20020200, 0x20020202, 0x21000101, 0x21010000, 0x21010001, 0x21010100,
0x21010102, 0x21010201, 0x21020101, 0x22000000, 0x22000002, 0x22000200, 0x22000202, 0x22010101,
0x22020000, 0x22020002, 0x22020200, 0x22020202, 0x20000111, 0x20010011, 0x20010110, 0x20010112,
0x20010211, 0x20020111, 0x21000011, 0x21000110, 0x21000211, 0x21010010, 0x21010012, 0x21010111,
0x21010112, 0x21010210, 0x21010211, 0x21020110, 0x21020112, 0x21020211, 0x22000111, 0x22000211,
0x22010110, 0x22010112, 0x22010211, 0x22020111, 0x20000020, 0x20000022, 0x20000220, 0x20000222,
0x20010121, 0x20020020, 0x20020022, 0x20020220, 0x20020222, 0x21010021, 0x21010120, 0x21010221,
0x21020121, 0x22000020, 0x22000022, 0x22000220, 0x22000222, 0x22010121, 0x22020020, 0x22020022,
0x22020220, 0x22020222, 0x20011100, 0x20011201, 0x21001001, 0x21001100, 0x21011001, 0x21011101,
0x21011202, 0x21021001, 0x21021100, 0x21021201, 0x22011100, 0x22011201, 0x20001011, 0x20001211,
0x20011012, 0x20011111, 0x20011212, 0x20021112, 0x20021211, 0x21001010, 0x21001011, 0x21001111,
0x21001210, 0x21011011, 0x21011110, 0x21011111, 0x21011112, 0x21011211, 0x21011212, 0x21021111,
0x21021112, 0x21021210, 0x21021212, 0x22001011, 0x22001110, 0x22001112, 0x22001211, 0x22011010,
0x22011012, 0x22011111, 0x22011210, 0x22021112, 0x20011021, 0x20011122, 0x20011221, 0x20021121,
0x21001021, 0x21001120, 0x21001221, 0x21001222, 0x21011020, 0x21011121, 0x21011221, 0x21011222,
0x21021021, 0x21021122, 0x21021222, 0x22001121, 0x22011021, 0x22011222, 0x22021120, 0x20002000,
0x20002002, 0x20002200, 0x20002202, 0x20012101, 0x20022000, 0x20022002, 0x20022200, 0x20022202,
0x21002001, 0x21002101, 0x21012001, 0x21012100, 0x21012201, 0x21022101, 0x21022201, 0x22002000,
0x22002002, 0x22002200, 0x22002202, 0x22012101, 0x22022000, 0x22022002, 0x22022200, 0x22022202,
0x20002111, 0x20002112, 0x20012011, 0x20012110, 0x20012112, 0x20022111, 0x21002011, 0x21002110,
0x21002112, 0x21002211, 0x21012010, 0x21012012, 0x21012111, 0x21012212, 0x21022011, 0x21022110,
0x22002111, 0x22012112, 0x22012211, 0x22022111, 0x20002020, 0x20002022, 0x20002220, 0x20002222,
0x20012121, 0x20022020, 0x20022022, 0x20022220, 0x20022222, 0x21002121, 0x21012021, 0x21012120,
0x21012122, 0x22002020, 0x22002022, 0x22002220, 0x22002222, 0x22012121, 0x22022020, 0x22022022,
0x22022220, 0x22022222, 0x20100101, 0x20110001, 0x20110102, 0x20110200, 0x20110201, 0x20120101,
0x21100001, 0x21100102, 0x21100201, 0x21110101, 0x21110200, 0x21110202, 0x21120201, 0x21120202,
0x22100101, 0x22110001, 0x22110100, 0x22110102, 0x22110201, 0x22120101, 0x20100011, 0x20100110,
0x20100112, 0x20100211, 0x20110010, 0x20110111, 0x20110210, 0x20110212, 0x20120011, 0x20120110,
0x20120112, 0x20120211, 0x21100010, 0x21100111, 0x21110010, 0x21110011, 0x21110110, 0x21110111,
0x21110112, 0x21110211, 0x21120012, 0x21120111, 0x22100110, 0x22100112, 0x22110012, 0x22110111,
0x22110210, 0x22120011, 0x22120110, 0x22120112, 0x22120211, 0x20100121, 0x20110021, 0x20110120,
0x20110221, 0x20120121, 0x21100120, 0x21100122, 0x21100221, 0x21110020, 0x21110022, 0x21110121,
0x21110220, 0x21120122, 0x21120221, 0x22100121, 0x22110120, 0x22110122, 0x22120221, 0x20101001,
0x20101100, 0x20101102, 0x20111000, 0x20111101, 0x20111200, 0x20121102, 0x21101000, 0x21101202,
0x21111001, 0x21111100, 0x21111101, 0x21111102, 0x21111200, 0x21111201, 0x21121000, 0x21121001,
0x21121002, 0x21121101, 0x22101100, 0x22101102, 0x22111002, 0x22111100, 0x22111101, 0x22111200,
0x22121001, 0x22121201, 0x20101010, 0x20101111, 0x20101210, 0x20101212, 0x20111010, 0x20111011,
0x20111110, 0x20111111, 0x20111112, 0x20111211, 0x20121011, 0x20121111, 0x20121211, 0x20121212,
0x21101011, 0x21101110, 0x21101111, 0x21101112, 0x21101211, 0x21111010, 0x21111011, 0x21111012,
0x21111110, 0x21111111, 0x21111112, 0x21111210, 0x21111211, 0x21111212, 0x21121011, 0x21121110,
0x21121111, 0x21121112, 0x21121211, 0x22101011, 0x22101111, 0x22101210, 0x22111011, 0x22111012,
0x22111110, 0x22111111, 0x22111112, 0x22111211, 0x22111212, 0x22121010, 0x22121012, 0x22121111,
0x22121210, 0x22121212, 0x20101021, 0x20101120, 0x20111020, 0x20111121, 0x20111221, 0x20121020,
0x20121122, 0x20121221, 0x21101121, 0x21101220, 0x21101221, 0x21111021, 0x21111022, 0x21111121,
0x21111122, 0x21111221, 0x21121121, 0x21121220, 0x22101022, 0x22101120, 0x22101221, 0x22101222,
0x22111022, 0x22111120, 0x22111121, 0x22121120, 0x22121122, 0x22121221, 0x20102101, 0x20112102,
0x20112201, 0x20122101, 0x21102001, 0x21102102, 0x21112000, 0x21112002, 0x21112101, 0x21112102,
0x21112202, 0x21122100, 0x21122101, 0x22102101, 0x22112001, 0x22112102, 0x22112201, 0x22122101,
0x20102110, 0x20102112, 0x20102211, 0x20112010, 0x20112012, 0x20112111, 0x20112210, 0x20112212,
0x20122010, 0x20122011, 0x20122110, 0x20122112, 0x21102010, 0x21102012, 0x21102111, 0x21102210,
0x21102212, 0x21112011, 0x21112110, 0x21112111, 0x21112112, 0x21112211, 0x21122012, 0x21122111,
0x21122112, 0x21122212, 0x22102011, 0x22102110, 0x22112010, 0x22112012, 0x22112111, 0x22112212,
0x22122011, 0x22122112, 0x20102121, 0x20112121, 0x20122121, 0x21102120, 0x21102122, 0x21102221,
0x21112020, 0x21112121, 0x21112220, 0x21122021, 0x22102121, 0x22112021, 0x22112120, 0x22112121,
0x22112122, 0x20200000, 0x20200002, 0x20200200, 0x20200202, 0x20210101, 0x20220000, 0x20220002,
0x20220200, 0x20220202, 0x21200101, 0x21210001, 0x21210100, 0x21210102, 0x21210201, 0x22200000,
0x22200002, 0x22200200, 0x22200202, 0x22210101, 0x22220000, 0x22220002, 0x22220200, 0x22220202,
0x20200111, 0x20200211, 0x20210011, 0x20210110, 0x20210112, 0x20210211, 0x20210212, 0x21200112,
0x21200211, 0x21210011, 0x21210111, 0x21210210, 0x21210212, 0x21220011, 0x21220110, 0x22200111,
0x22210010, 0x22210012, 0x22210112, 0x22210211, 0x20200022, 0x20200220, 0x20200222, 0x20210020,
0x20210221, 0x20220022, 0x20220220, 0x20220222, 0x21200121, 0x21210021, 0x21210122, 0x21210221,
0x21220121, 0x22200020, 0x22200022, 0x22200220, 0x22200222, 0x22210121, 0x22220020, 0x22220022,
0x22220220, 0x22220222, 0x20211201, 0x20221101, 0x21201001, 0x21201100, 0x21211000, 0x21211100,
0x21211101, 0x21211200, 0x21211202, 0x21221001, 0x21221101, 0x21221102, 0x21221200, 0x21221201,
0x22201101, 0x20201112, 0x20201211, 0x20211010, 0x20211012, 0x20211111, 0x20211210, 0x20221112,
0x20221211, 0x21201012, 0x21201111, 0x21211011, 0x21211110, 0x21211111, 0x21211112, 0x21211211,
0x21221111, 0x21221212, 0x22201011, 0x22201110, 0x22201111, 0x22201112, 0x22201211, 0x22211012,
0x22211111, 0x22211210, 0x20201121, 0x20211021, 0x20211122, 0x20211222, 0x20221021, 0x20221121,
0x21201120, 0x21201122, 0x21201222, 0x21211022, 0x21211121, 0x21211122, 0x21211220, 0x21221020,
0x21221022, 0x22201122, 0x22211020, 0x22211121, 0x22211122, 0x22211221, 0x22221021, 0x22221120,
0x22221122, 0x20202000, 0x20202002, 0x20202200, 0x20202202, 0x20222000, 0x20222002, 0x20222200,
0x20222202, 0x21212001, 0x21212100, 0x21212102, 0x21212201, 0x22202000, 0x22202002, 0x22202200,
0x22202202, 0x22212101, 0x22222000, 0x22222002, 0x22222200, 0x22222202, 0x20202111, 0x20212110,
0x20212211, 0x20222011, 0x20222111, 0x21202011, 0x21212010, 0x21212111, 0x21212212, 0x21222011,
0x21222112, 0x21222211, 0x22212010, 0x22212112, 0x20202020, 0x20202022, 0x20202220, 0x20202222,
0x20222020, 0x20222022, 0x20222220, 0x20222222, 0x21212021, 0x21212120, 0x21212122, 0x22202020,
0x22202022, 0x22202220, 0x22202222, 0x22212121, 0x22222020, 0x22222022, 0x22222220, 0x22222222,
};
 shared uint16_t iq1s_grid[2048];
-shared uint32_t iq1s_grid_gpu[2048];
 #define NEEDS_INIT_IQ_SHMEM
 void init_iq_shmem(uvec3 wgsize)
@@ -857,12 +573,6 @@ void init_iq_shmem(uvec3 wgsize)
             iq1s_grid[2*idx+1] = g.y;
         }
     }
-    [[unroll]] for (uint i = 0; i < iq1s_grid_gpu_const.length(); i += wgsize.x) {
-        uint idx = i + gl_LocalInvocationIndex.x;
-        if (iq1s_grid_gpu_const.length() % wgsize.x == 0 || idx < iq1s_grid_gpu_const.length()) {
-            iq1s_grid_gpu[idx] = iq1s_grid_gpu_const[idx];
-        }
-    }
     barrier();
 }
 #endif
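The removed `init_iq_shmem` loop above is the usual cooperative shared-memory fill: every invocation strides through the constant table in steps of the workgroup width, and the `length() % wgsize.x == 0 || idx < length()` guard exists so the bounds check can be folded away whenever the table length divides evenly. A minimal host-side C++ sketch of the same indexing (table length and workgroup size here are invented for illustration, not taken from this diff):

#include <cstdint>
#include <cstdio>
#include <vector>

// Invented stand-ins for the shader's constant table and workgroup width.
constexpr uint32_t kTableLen = 2048; // e.g. iq1s_grid_gpu_const.length()
constexpr uint32_t kWgSize   = 96;   // deliberately does not divide kTableLen

int main() {
    std::vector<uint32_t> table(kTableLen, 0x01020100u); // constant data
    std::vector<uint32_t> shmem(kTableLen, 0);           // "shared" destination

    // One pass per invocation: invocation `local_id` copies elements
    // local_id, local_id + kWgSize, local_id + 2*kWgSize, ...
    for (uint32_t local_id = 0; local_id < kWgSize; ++local_id) {
        for (uint32_t i = 0; i < kTableLen; i += kWgSize) {
            const uint32_t idx = i + local_id;
            // When kTableLen % kWgSize == 0 the left operand is a
            // compile-time true and the per-element bound check disappears;
            // otherwise it protects the final, partial stride.
            if (kTableLen % kWgSize == 0 || idx < kTableLen) {
                shmem[idx] = table[idx];
            }
        }
    }
    std::printf("copied %zu entries\n", shmem.size());
    return 0;
}

On the GPU the same trick matters because the workgroup size is fixed at shader compile time, so the divisibility test can usually be resolved to a constant.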

View File

@@ -552,9 +552,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
     for (const auto& tname : type_names) {
         std::string load_vec_quant = "2";
-        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "q5_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
+        if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s"))
             load_vec_quant = "8";
-        else if ((tname == "q5_0") || (tname == "q8_0") || (tname == "q2_k") || (tname == "q4_k") || (tname == "q5_k") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
+        else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4"))
             load_vec_quant = "4";

         if (tname == "bf16") {
@@ -685,7 +685,7 @@ void process_shaders() {
         // mul mat vec with integer dot product
 #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
-        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname) || tname == "iq1_s" || tname == "iq1_m") {
+        if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
             string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
@@ -944,8 +944,6 @@ void process_shaders() {
     string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}));
     string_to_spv("cumsum_f32", "cumsum.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("cumsum_multipass1_f32", "cumsum_multipass1.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
-    string_to_spv("cumsum_multipass2_f32", "cumsum_multipass2.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
     string_to_spv("count_experts", "count_experts.comp", merge_maps(base_dict, {{"A_TYPE", "uint"}, {"D_TYPE", "uint"}}));
@@ -1125,7 +1123,7 @@ void write_output_files() {
     for (const std::string& btype : btypes) {
         for (const auto& tname : type_names) {
-            if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname) && tname != "iq1_s" && tname != "iq1_m") {
+            if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
                 continue;
             }
             hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";

Some files were not shown because too many files have changed in this diff