Merge pull request #1 from Iemand005/master

Update MSVC fix branch
commit d69326df7c
Lasse Lauwerys, committed by GitHub, 2026-02-04 19:21:18 +01:00
GPG Key ID: B5690EEEBB952194 (no known key found for this signature in database)
208 changed files with 9293 additions and 4502 deletions

View File

@@ -4,7 +4,7 @@
 # the module `{ pkgs ... }: { /* config */ }` implicitly uses
 # `_module.args.pkgs` (defined in this case by flake-parts).
 perSystem =
-  { system, ... }:
+  { lib, system, ... }:
   {
     _module.args = {
       # Note: bringing up https://zimbatm.com/notes/1000-instances-of-nixpkgs
@@ -33,7 +33,7 @@
         "CUDA EULA"
         "cuDNN EULA"
       ]
-    ) (p.meta.licenses or [ p.meta.license ]);
+    ) (p.meta.licenses or (lib.toList p.meta.license));
   };
   # Ensure dependencies use ROCm consistently
   pkgsRocm = import inputs.nixpkgs {

View File

@@ -3,6 +3,7 @@
 llamaVersion,
 numpy,
 tqdm,
+requests,
 sentencepiece,
 pyyaml,
 poetry-core,
@@ -20,6 +21,7 @@ buildPythonPackage {
   tqdm
   sentencepiece
   pyyaml
+  requests
 ];
 src = lib.cleanSource ../../gguf-py;
 pythonImportsCheck = [

View File

@@ -7,13 +7,6 @@
 let
   pythonPackages = python3.pkgs;
-  buildPythonPackage = pythonPackages.buildPythonPackage;
-  numpy = pythonPackages.numpy;
-  tqdm = pythonPackages.tqdm;
-  sentencepiece = pythonPackages.sentencepiece;
-  pyyaml = pythonPackages.pyyaml;
-  poetry-core = pythonPackages.poetry-core;
-  pytestCheckHook = pythonPackages.pytestCheckHook;
 in

 # We're using `makeScope` instead of just writing out an attrset
@@ -23,17 +16,18 @@ in
 lib.makeScope newScope (self: {
   inherit llamaVersion;
   gguf-py = self.callPackage ./package-gguf-py.nix {
-    inherit
-      buildPythonPackage
+    inherit (pythonPackages)
       numpy
       tqdm
       sentencepiece
-      poetry-core
       pyyaml
       pytestCheckHook
+      requests
+      buildPythonPackage
+      poetry-core
       ;
   };
-  python-scripts = self.callPackage ./python-scripts.nix { inherit buildPythonPackage poetry-core; };
+  python-scripts = self.callPackage ./python-scripts.nix { inherit (pythonPackages) buildPythonPackage poetry-core; };
   llama-cpp = self.callPackage ./package.nix { };
   docker = self.callPackage ./docker.nix { };
   docker-min = self.callPackage ./docker.nix { interactive = false; };

View File

@@ -21,7 +21,8 @@ on:
 '**/*.m',
 '**/*.metal',
 '**/*.comp',
-'**/*.glsl'
+'**/*.glsl',
+'**/*.wgsl'
 ]
 pull_request:
@@ -42,7 +43,8 @@ on:
 '**/*.m',
 '**/*.metal',
 '**/*.comp',
-'**/*.glsl'
+'**/*.glsl',
+'**/*.wgsl'
 ]
 concurrency:
@@ -291,6 +293,7 @@ jobs:
 cmake -B build \
   -DLLAMA_FATAL_WARNINGS=ON \
   -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
+  -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
   -DCMAKE_BUILD_TYPE=${{ matrix.build_type }}
 cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
@@ -301,6 +304,7 @@ jobs:
 cmake -B build \
   -DLLAMA_FATAL_WARNINGS=ON \
   -DLLAMA_SANITIZE_${{ matrix.sanitizer }}=ON \
+  -DGGML_SANITIZE_${{ matrix.sanitizer }}=ON \
   -DCMAKE_BUILD_TYPE=${{ matrix.build_type }} \
   -DGGML_OPENMP=OFF
 cmake --build build --config ${{ matrix.build_type }} -j $(nproc)
@@ -1371,7 +1375,7 @@ jobs:
 id: update_presets
 if: ${{ matrix.build == 'arm64-snapdragon' }}
 run: |
-  cp docs/backend/hexagon/CMakeUserPresets.json .
+  cp docs/backend/snapdragon/CMakeUserPresets.json .
 - name: Build
 id: ndk_build
@@ -1530,7 +1534,7 @@
 - name: Test
 id: ggml-ci
 run: |
-  LLAMA_ARG_THREADS=$(nproc) bash ./ci/run.sh ./tmp/results ./tmp/mnt
+  LLAMA_ARG_THREADS=$(nproc) GG_BUILD_HIGH_PERF=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
 ggml-ci-arm64-cpu-high-perf:
 runs-on: ubuntu-22.04-arm
@@ -1556,7 +1560,7 @@
 - name: Test
 id: ggml-ci
 run: |
-  LLAMA_ARG_THREADS=$(nproc) GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
+  LLAMA_ARG_THREADS=$(nproc) GG_BUILD_HIGH_PERF=1 GG_BUILD_NO_SVE=1 GG_BUILD_NO_BF16=1 GG_BUILD_EXTRA_TESTS_0=1 bash ./ci/run.sh ./tmp/results ./tmp/mnt
 ggml-ci-arm64-cpu-high-perf-sve:
 runs-on: ubuntu-22.04-arm

View File

@@ -36,7 +36,7 @@ jobs:
 strategy:
 matrix:
-  sanitizer: [ADDRESS, UNDEFINED] # THREAD is broken
+  sanitizer: [ADDRESS, UNDEFINED] # THREAD is very slow
 build_type: [RelWithDebInfo]
 include:
   - build_type: Release
@@ -45,7 +45,7 @@
   - build_type: Release
     sanitizer: ""
     extra_args: "LLAMA_ARG_BACKEND_SAMPLING=1"
-fail-fast: false # While -DLLAMA_SANITIZE_THREAD=ON is broken
+fail-fast: false
 steps:
 - name: Dependencies
@@ -72,7 +72,15 @@
 - name: Build
 id: cmake_build
 run: |
-  cmake -B build -DLLAMA_BUILD_BORINGSSL=ON -DGGML_SCHED_NO_REALLOC=ON
+  cmake -B build \
+    -DLLAMA_BUILD_BORINGSSL=ON \
+    -DGGML_SCHED_NO_REALLOC=ON \
+    -DGGML_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
+    -DGGML_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
+    -DGGML_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }} \
+    -DLLAMA_SANITIZE_ADDRESS=${{ matrix.sanitizer == 'ADDRESS' }} \
+    -DLLAMA_SANITIZE_THREAD=${{ matrix.sanitizer == 'THREAD' }} \
+    -DLLAMA_SANITIZE_UNDEFINED=${{ matrix.sanitizer == 'UNDEFINED' }}
   cmake --build build --config ${{ matrix.build_type }} -j ${env:NUMBER_OF_PROCESSORS} --target llama-server
 - name: Python setup
@@ -88,7 +96,7 @@
 - name: Tests
 id: server_integration_tests
-if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) && matrix.build_type == 'Release' }}
+if: ${{ (!matrix.disabled_on_pr || !github.event.pull_request) }}
 run: |
   cd tools/server/tests
   export ${{ matrix.extra_args }}

AUTHORS (1085 lines changed)

File diff suppressed because it is too large.

View File

@@ -164,29 +164,6 @@ llama_option_depr(WARNING LLAMA_SYCL GGML_SYCL)
 llama_option_depr(WARNING LLAMA_SYCL_F16 GGML_SYCL_F16)
 llama_option_depr(WARNING LLAMA_CANN GGML_CANN)
-if (NOT MSVC)
-    if (LLAMA_SANITIZE_THREAD)
-        message(STATUS "Using -fsanitize=thread")
-        add_compile_options(-fsanitize=thread)
-        link_libraries (-fsanitize=thread)
-    endif()
-    if (LLAMA_SANITIZE_ADDRESS)
-        message(STATUS "Using -fsanitize=address")
-        add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
-        link_libraries (-fsanitize=address)
-    endif()
-    if (LLAMA_SANITIZE_UNDEFINED)
-        message(STATUS "Using -fsanitize=undefined")
-        add_compile_options(-fsanitize=undefined)
-        link_libraries (-fsanitize=undefined)
-    endif()
-endif()
 include("cmake/license.cmake")
 license_add_file("llama.cpp" "LICENSE")

View File

@@ -1,6 +1,6 @@
 MIT License

-Copyright (c) 2023-2024 The ggml authors
+Copyright (c) 2023-2026 The ggml authors

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

View File

@@ -213,6 +213,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
 - [llama.vim](https://github.com/ggml-org/llama.vim) (MIT)
 - [LARS](https://github.com/abgulati/LARS) (AGPL)
 - [Llama Assistant](https://github.com/vietanhdev/llama-assistant) (GPL)
+- [LlamaLib](https://github.com/undreamai/LlamaLib) (Apache-2.0)
 - [LLMFarm](https://github.com/guinmoon/LLMFarm?tab=readme-ov-file) (MIT)
 - [LLMUnity](https://github.com/undreamai/LLMUnity) (MIT)
 - [LMStudio](https://lmstudio.ai/) (proprietary)

View File

@@ -635,6 +635,29 @@ function gg_check_build_requirements {
     fi
 }
function gg_run_test_backend_ops_cpu {
cd ${SRC}
cd build-ci-release
set -e
(time ./bin/test-backend-ops -b CPU ) 2>&1 | tee -a $OUT/${ci}-test-backend-ops-cpu.log
set +e
}
function gg_sum_test_backend_ops_cpu {
gg_printf '### %s\n\n' "${ci}"
gg_printf 'Runs test-backend-ops for CPU backend\n'
gg_printf '- status: %s\n' "$(cat $OUT/${ci}.exit)"
gg_printf '```\n'
gg_printf '%s\n' "$(cat $OUT/${ci}-test-backend-ops-cpu.log)"
gg_printf '```\n'
gg_printf '\n'
}
 ## main
 export LLAMA_LOG_PREFIX=1
@@ -663,6 +686,10 @@ ret=0
 test $ret -eq 0 && gg_run ctest_debug
 test $ret -eq 0 && gg_run ctest_release
if [ ! -z ${GG_BUILD_HIGH_PERF} ]; then
test $ret -eq 0 && gg_run test_backend_ops_cpu
fi
 if [ -z ${GG_BUILD_LOW_PERF} ]; then
     test $ret -eq 0 && gg_run embd_bge_small
     test $ret -eq 0 && gg_run rerank_tiny

View File

@@ -32,4 +32,27 @@ function(llama_add_compile_flags)
         set(CXX_FLAGS "" PARENT_SCOPE)
     endif()
 endif()
if (NOT MSVC)
if (LLAMA_SANITIZE_THREAD)
message(STATUS "Using -fsanitize=thread")
add_compile_options(-fsanitize=thread)
link_libraries (-fsanitize=thread)
endif()
if (LLAMA_SANITIZE_ADDRESS)
message(STATUS "Using -fsanitize=address")
add_compile_options(-fsanitize=address -fno-omit-frame-pointer)
link_libraries (-fsanitize=address)
endif()
if (LLAMA_SANITIZE_UNDEFINED)
message(STATUS "Using -fsanitize=undefined")
add_compile_options(-fsanitize=undefined)
link_libraries (-fsanitize=undefined)
endif()
endif()
endfunction()

View File

@@ -75,6 +75,8 @@ add_library(${TARGET} STATIC
     ngram-cache.h
     ngram-map.cpp
    ngram-map.h
+    ngram-mod.cpp
+    ngram-mod.h
     peg-parser.cpp
     peg-parser.h
     preset.cpp

View File

@@ -1301,7 +1301,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         [](common_params & params, bool value) {
             params.kv_unified = value;
         }
-    ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED}));
+    ).set_env("LLAMA_ARG_KV_UNIFIED").set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_BATCHED, LLAMA_EXAMPLE_BENCH}));
     add_opt(common_arg(
         {"--context-shift"},
         {"--no-context-shift"},
@@ -3396,7 +3396,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
         }
     ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_CLI}));
     add_opt(common_arg(
-        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v]",
+        {"--spec-type"}, "[none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]",
         string_format("type of speculative decoding to use when no draft model is provided (default: %s)\n",
             common_speculative_type_to_str(params.speculative.type).c_str()),
         [](common_params & params, const std::string & value) {
@@ -3410,6 +3410,8 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K;
             } else if (value == "ngram-map-k4v") {
                 params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V;
+            } else if (value == "ngram-mod") {
+                params.speculative.type = COMMON_SPECULATIVE_TYPE_NGRAM_MOD;
             } else {
                 throw std::invalid_argument("unknown speculative decoding type without draft model");
             }

View File

@@ -771,10 +771,12 @@ static std::string apply(
     nlohmann::ordered_json inp = nlohmann::ordered_json{
         {"messages", messages_override.has_value() ? *messages_override : inputs.messages},
-        {"tools", tools_override.has_value() ? *tools_override : inputs.tools},
         {"bos_token", tmpl.bos_token()},
         {"eos_token", tmpl.eos_token()},
     };
+    if (tools_override.has_value() || !inputs.tools.empty()) {
+        inp["tools"] = tools_override.has_value() ? *tools_override : inputs.tools;
+    }
     if (inputs.extra_context.is_object()) {
         // TODO: do we need to merge, or replacing is fine?
         for (const auto & [k, v] : inputs.extra_context.items()) {
@@ -790,9 +792,6 @@ static std::string apply(
     if (inputs.add_generation_prompt) {
         inp["add_generation_prompt"] = true;
     }
-    if (inp["tools"].is_null()) {
-        inp["tools"] = json::array();
-    }
     jinja::global_from_json(ctx, inp, inputs.mark_input);
@@ -2219,12 +2218,11 @@ static common_chat_params common_chat_params_init_glm_4_5(const common_chat_temp
 static common_chat_params common_chat_params_init_firefunction_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
     LOG_DBG("%s\n", __func__);
     common_chat_params data;
-    const std::optional<json> tools_override = json();
     const std::optional<json> additional_context = json {
         {"datetime", format_time(inputs.now, "%b %d %Y %H:%M:%S GMT")},
         {"functions", json(inputs.tools.empty() ? "" : inputs.tools.dump(2))},
     };
-    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, tools_override, additional_context);
+    data.prompt = apply(tmpl, inputs, /* messages_override =*/ std::nullopt, /* tools_override =*/ std::nullopt, additional_context);
     if (inputs.tools.is_array() && !inputs.tools.empty()) {
         data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED;
         data.grammar = build_grammar([&](const common_grammar_builder & builder) {
@@ -2573,20 +2571,165 @@ static common_chat_params common_chat_params_init_granite(const common_chat_temp
 static common_chat_params common_chat_params_init_solar_open(const common_chat_template & tmpl, const struct templates_params & inputs) {
     common_chat_params data;
-    // TODO: Reasoning effort
-    json additional_context = {};
    // Copy `reasoning_content` to `reasoning`
    auto adjusted_messages = json::array();
for (const auto & msg : inputs.messages) {
if (msg.contains("reasoning_content") && msg.at("reasoning_content").is_string()) {
auto adjusted_message = msg;
adjusted_message["reasoning"] = msg.at("reasoning_content");
adjusted_message.erase("reasoning_content");
adjusted_messages.push_back(adjusted_message);
} else {
adjusted_messages.push_back(msg);
}
}
-    data.prompt = apply(tmpl, inputs, std::nullopt, std::nullopt, additional_context);
-    data.format = COMMON_CHAT_FORMAT_SOLAR_OPEN;
    auto has_tools = inputs.tools.is_array() && !inputs.tools.empty();
    auto include_grammar = true;
auto prompt = apply(tmpl, inputs, /* messages_override= */ adjusted_messages);
// Check if we need to replace the flush token with end token during inference and without generation prompt.
if (inputs.is_inference && !inputs.add_generation_prompt) {
static constexpr std::string_view return_token = "<|flush|>";
static constexpr std::string_view end_token = "<|end|>";
if (size_t pos = prompt.rfind(return_token); pos != std::string::npos) {
prompt.replace(pos, return_token.length(), end_token);
}
}
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_PEG_NATIVE;
    data.preserved_tokens = {
        "<|think|>",
        "<|content|>",
        "<|begin|>",
        "<|end|>",
"<|tool_calls|>",
"<|tool_call:begin|>",
"<|tool_call:end|>",
"<|tool_call:name|>",
"<|tool_call:args|>",
    };
-    // TODO: Tool calling
    auto parser = build_chat_peg_native_parser([&](common_chat_peg_native_builder & p) {
auto lit_think = p.atomic(p.literal("<|think|>"));
auto lit_assistant_begin = p.atomic(p.literal("<|begin|>assistant"));
auto lit_content = p.atomic(p.literal("<|content|>"));
auto lit_end = p.atomic(p.literal("<|end|>"));
auto parser_until_end = p.until("<|end|>");
// reasoning <- "<|think|>" (!"<|end|>" .)*
auto parser_reasoning = p.rule("reasoning", lit_think + p.reasoning(parser_until_end));
// content <- "<|content|>" (!"<|end|>" .)*
auto parser_content = p.rule("content", lit_content + p.content(parser_until_end));
// wrap_choice(items) <- item-choice wrapped*
// item-choice <- items[0] / ... / items[n]
// wrapped <- "<|end|><|begin|>assistant" item-choice
auto wrap_choice = [&](const std::vector<common_peg_parser> & items) {
auto choice = p.choice(items);
return choice + p.zero_or_more(lit_end + lit_assistant_begin + choice);
};
// wrap_seq(items) <- item[0] "<|end|><|begin|>assistant" item[1] ...
auto wrap_seq = [&](const std::vector<common_peg_parser> & items) {
auto seq = p.sequence();
for (auto i = 0u; i < items.size(); i++) {
if (i == 0) {
seq += items[i];
continue;
}
seq += lit_end + lit_assistant_begin + items[i];
}
return seq;
};
// Response format parser
if (inputs.json_schema.is_object() && !inputs.json_schema.empty()) {
auto parser_response_format = lit_content + p.content(p.schema(p.json(), "response-format", inputs.json_schema));
return p.choice({
wrap_seq({parser_reasoning, parser_response_format}),
wrap_seq({parser_response_format})
});
}
auto lit_tool_call_begin = p.literal("<|tool_call:begin|>");
auto lit_tool_call_name = p.literal("<|tool_call:name|>");
auto lit_tool_call_args = p.literal("<|tool_call:args|>");
auto lit_tool_call_end = p.literal("<|tool_call:end|>");
// Tool call parser
if (has_tools && inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_NONE) {
auto parser_tool_call = p.choice();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
const auto & schema = function.at("parameters");
// tool(name, schema) <- name "<|tool_call:args|>" schema
parser_tool_call |= p.rule("tool-" + name,
p.atomic(p.tool_name(p.literal(name)) + lit_tool_call_args)
+ p.tool_args(p.schema(p.json(), "tool-" + name + "-schema", schema)));
});
auto min_calls = inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_REQUIRED ? 1 : 0;
auto max_calls = inputs.parallel_tool_calls ? -1 : 1;
// tool-calls <- "<|tool_calls|>" tool-call+
// tool-call <- "<|tool_call:begin|> call-id "<|tool_call:name|>" &([^<]+ "<|tool_call:args|>") tool-choice "<|tool_call:end|>"
// call-id <- [a-zA-Z0-9_-]+
// tool-choice <- tool(t[0].name, t[0].schema) / ... / tool(t[n].name, t[n].schema)
auto parser_tool_calls = p.trigger_rule("tool-calls",
p.atomic(p.literal("<|tool_calls|>"))
+ p.repeat(
p.tool_open(
lit_tool_call_begin
+ p.tool_id(p.chars("[a-zA-Z0-9_-]", 1, -1))
+ lit_tool_call_name
+ p.peek(p.chars("[^<]", 1, -1) + lit_tool_call_args))
+ parser_tool_call
+ p.tool_close(lit_tool_call_end),
/* min = */ 1,
/* max = */ max_calls));
if (min_calls == 1) {
// If required, then try any combination of the reasoning, content, and tool call
return p.choice({
wrap_seq({parser_reasoning, parser_content, parser_tool_calls}),
wrap_seq({parser_reasoning, parser_tool_calls}),
wrap_seq({parser_content, parser_tool_calls}),
wrap_seq({parser_tool_calls})
});
}
return wrap_choice({parser_reasoning, parser_content, parser_tool_calls});
}
// Content only parser
include_grammar = false;
return wrap_choice({parser_reasoning, parser_content});
});
data.parser = parser.save();
if (include_grammar) {
data.grammar_lazy = has_tools && inputs.tool_choice == COMMON_CHAT_TOOL_CHOICE_AUTO;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
auto schema = function.at("parameters");
builder.resolve_refs(schema);
});
parser.build_grammar(builder, data.grammar_lazy);
});
data.grammar_triggers = {
{COMMON_GRAMMAR_TRIGGER_TYPE_WORD, "<|tool_calls|>"}
};
}
    return data;
}
@@ -3043,6 +3186,13 @@
         return common_chat_params_init_apriel_1_5(tmpl, params);
     }
+    // Solar Open
+    if (src.find("<|tool_response:begin|>") != std::string::npos &&
+        src.find("<|tool_response:name|>") != std::string::npos &&
+        src.find("<|tool_response:result|>") != std::string::npos) {
+        return common_chat_params_init_solar_open(tmpl, params);
+    }
     // Use generic handler when mixing tools + JSON schema.
     // TODO: support that mix in handlers below.
     if ((params.tools.is_array() && params.json_schema.is_object())) {

View File

@@ -171,6 +171,7 @@ enum common_speculative_type {
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,  // simple self-speculative decoding
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,   // self-speculative decoding with n-gram keys only
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, // self-speculative decoding with n-gram keys and 4 m-gram values
+    COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
     COMMON_SPECULATIVE_TYPE_NGRAM_CACHE,   // self-speculative decoding with 3-level n-gram cache
     COMMON_SPECULATIVE_TYPE_COUNT          // number of types, unknown type
 };
@@ -252,6 +253,8 @@ struct common_params_model {
     std::string name = ""; // in format <user>/<model>[:<tag>] (tag is optional) // NOLINT
 };
+struct common_ngram_mod;
 struct common_params_speculative {
     common_speculative_type type = COMMON_SPECULATIVE_TYPE_NONE; // type of speculative decoding
@@ -269,6 +272,8 @@ struct common_params_speculative {
     uint16_t ngram_check_rate = 1; // check rate for ngram lookup
     uint16_t ngram_min_hits   = 1; // minimum hits at ngram/mgram lookup for mgram to be proposed
+    std::shared_ptr<common_ngram_mod> ngram_mod;
     std::string lookup_cache_static;  // path of static ngram cache file for lookup decoding  // NOLINT
     std::string lookup_cache_dynamic; // path of dynamic ngram cache file for lookup decoding // NOLINT

View File

@@ -45,6 +45,8 @@ static float common_ggml_get_float_value(const uint8_t * data,
     return v;
 }
+#define INDENT " "
 template <bool abort>
 void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * ne, const size_t * nb, int64_t n) {
     GGML_ASSERT(n > 0);
@@ -60,41 +62,41 @@ void common_debug_print_tensor(uint8_t * data, ggml_type type, const int64_t * n
         }
     }
     for (int64_t i3 = 0; i3 < ne[3]; i3++) {
-        LOG_ERR(" [\n");
+        LOG(INDENT "[\n");
         for (int64_t i2 = 0; i2 < ne[2]; i2++) {
             if (i2 == n && ne[2] > 2 * n) {
-                LOG_ERR(" ..., \n");
+                LOG(INDENT INDENT "..., \n");
                 i2 = ne[2] - n;
             }
-            LOG_ERR(" [\n");
+            LOG(INDENT INDENT "[\n");
             for (int64_t i1 = 0; i1 < ne[1]; i1++) {
                 if (i1 == n && ne[1] > 2 * n) {
-                    LOG_ERR(" ..., \n");
+                    LOG(INDENT INDENT INDENT "..., \n");
                     i1 = ne[1] - n;
                 }
-                LOG_ERR(" [");
+                LOG(INDENT INDENT INDENT "[");
                 for (int64_t i0 = 0; i0 < ne[0]; i0++) {
                     if (i0 == n && ne[0] > 2 * n) {
-                        LOG_ERR("..., ");
+                        LOG(" ..., ");
                         i0 = ne[0] - n;
                     }
                     const float v = common_ggml_get_float_value(data, type, nb, i0, i1, i2, i3);
-                    LOG_ERR("%12.4f", v);
+                    LOG("%12.4f", v);
                     if (i0 < ne[0] - 1) {
-                        LOG_ERR(", ");
+                        LOG(", ");
                     }
                 }
-                LOG_ERR("],\n");
+                LOG(" ],\n");
             }
-            LOG_ERR(" ],\n");
+            LOG(INDENT INDENT "],\n");
         }
-        LOG_ERR(" ]\n");
-        LOG_ERR(" sum = %f\n", sum);
+        LOG(INDENT "]\n");
+        LOG(INDENT "sum = %f\n", sum);
     }
     if constexpr (abort) {
         if (std::isnan(sum)) {
-            LOG_ERR("encountered NaN - aborting\n");
+            LOG("encountered NaN - aborting\n");
             exit(0);
         }
     }
@@ -137,9 +139,9 @@ template <bool abort_on_nan> bool common_debug_cb_eval(struct ggml_tensor * t, b
     }
     if (matches_filter) {
-        LOG_ERR("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type),
+        LOG("%s: %24s = (%s) %10s(%s{%s}, %s}) = {%s}\n", __func__, t->name, ggml_type_name(t->type),
             ggml_op_desc(t), src0->name, common_ggml_ne_string(src0).c_str(), src1 ? src1_str : "",
             common_ggml_ne_string(t).c_str());
     }
     const bool is_host = ggml_backend_buffer_is_host(t->buffer);

View File

@@ -144,6 +144,13 @@ value binary_expression::execute_impl(context & ctx) {
         return false;
     };
+    auto test_is_in = [&]() -> bool {
+        func_args args(ctx);
+        args.push_back(left_val);
+        args.push_back(right_val);
+        return global_builtins().at("test_is_in")(args)->as_bool();
+    };
     // Handle undefined and null values
     if (is_val<value_undefined>(left_val) || is_val<value_undefined>(right_val)) {
         if (is_val<value_undefined>(right_val) && (op.value == "in" || op.value == "not in")) {
@@ -223,19 +230,11 @@
             return result;
         }
     } else if (is_val<value_array>(right_val)) {
-        auto & arr = right_val->as_array();
-        bool member = false;
-        for (const auto & item : arr) {
-            if (*left_val == *item) {
-                member = true;
-                break;
-            }
-        }
+        // case: 1 in [0, 1, 2]
+        bool member = test_is_in();
         if (op.value == "in") {
-            JJ_DEBUG("Checking membership: %s in Array is %d", left_val->type().c_str(), member);
             return mk_val<value_bool>(member);
         } else if (op.value == "not in") {
-            JJ_DEBUG("Checking non-membership: %s not in Array is %d", left_val->type().c_str(), !member);
             return mk_val<value_bool>(!member);
         }
     }
@@ -252,22 +251,23 @@
     // String membership
     if (is_val<value_string>(left_val) && is_val<value_string>(right_val)) {
-        auto left_str = left_val->as_string().str();
-        auto right_str = right_val->as_string().str();
+        // case: "a" in "abc"
+        bool member = test_is_in();
         if (op.value == "in") {
-            return mk_val<value_bool>(right_str.find(left_str) != std::string::npos);
+            return mk_val<value_bool>(member);
         } else if (op.value == "not in") {
-            return mk_val<value_bool>(right_str.find(left_str) == std::string::npos);
+            return mk_val<value_bool>(!member);
         }
     }
     // Value key in object
     if (is_val<value_object>(right_val)) {
-        bool has_key = right_val->has_key(left_val);
+        // case: key in {key: value}
+        bool member = test_is_in();
         if (op.value == "in") {
-            return mk_val<value_bool>(has_key);
+            return mk_val<value_bool>(member);
         } else if (op.value == "not in") {
-            return mk_val<value_bool>(!has_key);
+            return mk_val<value_bool>(!member);
         }
     }

View File

@@ -393,6 +393,33 @@ const func_builtins & global_builtins() {
         {"test_is_lt", test_compare_fn<value_compare_op::lt>},
         {"test_is_lessthan", test_compare_fn<value_compare_op::lt>},
         {"test_is_ne", test_compare_fn<value_compare_op::ne>},
{"test_is_in", [](const func_args & args) -> value {
args.ensure_count(2);
auto needle = args.get_pos(0);
auto haystack = args.get_pos(1);
if (is_val<value_undefined>(haystack)) {
return mk_val<value_bool>(false);
}
if (is_val<value_array>(haystack)) {
for (const auto & item : haystack->as_array()) {
if (*needle == *item) {
return mk_val<value_bool>(true);
}
}
return mk_val<value_bool>(false);
}
if (is_val<value_string>(haystack)) {
if (!is_val<value_string>(needle)) {
throw raised_exception("'in' test expects args[1] as string when args[0] is string, got args[1] as " + needle->type());
}
return mk_val<value_bool>(
haystack->as_string().str().find(needle->as_string().str()) != std::string::npos);
}
if (is_val<value_object>(haystack)) {
return mk_val<value_bool>(haystack->has_key(needle));
}
throw raised_exception("'in' test expects iterable as first argument, got " + haystack->type());
}},
{"test_is_test", [](const func_args & args) -> value { {"test_is_test", [](const func_args & args) -> value {
args.ensure_vals<value_string>(); args.ensure_vals<value_string>();
auto & builtins = global_builtins(); auto & builtins = global_builtins();
@ -1028,6 +1055,16 @@ const func_builtins & value_none_t::get_builtins() const {
{"safe", [](const func_args &) -> value { {"safe", [](const func_args &) -> value {
return mk_val<value_string>("None"); return mk_val<value_string>("None");
}}, }},
{"strip", [](const func_args &) -> value {
return mk_val<value_string>("None");
}},
{"items", empty_value_fn<value_array>},
{"map", empty_value_fn<value_array>},
{"reject", empty_value_fn<value_array>},
{"rejectattr", empty_value_fn<value_array>},
{"select", empty_value_fn<value_array>},
{"selectattr", empty_value_fn<value_array>},
{"unique", empty_value_fn<value_array>},
    };
    return builtins;
}
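For reference, the membership rules encoded by the new test_is_in builtin above can be reproduced outside the template engine. The following standalone C++ sketch is illustrative only: it is not part of this PR and uses plain std containers instead of the engine's value wrappers, but it mirrors the three cases the builtin handles (element lookup in an array, substring search in a string, key lookup in an object).

// Standalone illustration of the 'in' test semantics added above.
#include <algorithm>
#include <cstdio>
#include <map>
#include <string>
#include <vector>

// needle in array: compare elements
static bool is_in(int needle, const std::vector<int> & haystack) {
    return std::find(haystack.begin(), haystack.end(), needle) != haystack.end();
}

// needle in string: substring search
static bool is_in(const std::string & needle, const std::string & haystack) {
    return haystack.find(needle) != std::string::npos;
}

// needle in object: key lookup
static bool is_in(const std::string & key, const std::map<std::string, int> & haystack) {
    return haystack.count(key) > 0;
}

int main() {
    const std::vector<int> arr = { 0, 1, 2 };
    const std::map<std::string, int> obj = { { "key", 1 } };

    std::printf("1 in [0, 1, 2]      -> %d\n", is_in(1, arr));                                    // 1
    std::printf("\"a\" in \"abc\"        -> %d\n", is_in(std::string("a"), std::string("abc")));  // 1
    std::printf("key in {key: value} -> %d\n", is_in(std::string("key"), obj));                   // 1
    std::printf("3 in [0, 1, 2]      -> %d\n", is_in(3, arr));                                    // 0
    return 0;
}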

View File

@@ -12,6 +12,7 @@
 #include <set>
 #include <sstream>
 #include <string>
+#include <unordered_map>
 #include <vector>

 namespace jinja {

View File

@@ -7,6 +7,33 @@
 #include <cstdio>
 #include <sstream>
// prime number used for LCG hash function (32 bit), it is near (sqrt(5) - 1)/2 * 2^32.
#define LCG_FACTOR 2654435761UL
// Compute the LCG hash of a n-gram of size len at offset start.
static uint32_t common_ngram_map_hash(const llama_tokens & tokens, size_t start, size_t len) {
uint32_t hash = 0;
for (size_t i = 0; i < len; ++i) {
hash = hash * LCG_FACTOR + tokens[start + i];
}
return hash;
}
// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
std::ostringstream oss;
oss << '[';
for (size_t i = 0; i < length; ++i) {
if (i > 0) {
oss << ", ";
}
oss << inp[start + i];
}
oss << ']';
return oss.str();
}
 // n-gram simple
 //
@@ -20,21 +47,15 @@
  * @return Vector of draft tokens, empty if no matching pattern is found
  */
 llama_tokens common_ngram_simple_draft(
-        common_ngram_simple_state & state,
+        const common_ngram_simple_config & config,
         const llama_tokens & tokens, llama_token sampled) {
     // Simple implementation of self-speculative decoding without a draft model.
     //
     const size_t cur_len = tokens.size();
-    // Only check every check_rate tokens to save compute
-    // i.e., perform check if (cur_len - idx_last_check) >= check_rate
-    if (state.idx_last_check + state.config.check_rate > cur_len) {
-        llama_tokens draft_tokens;
-        return draft_tokens;
-    }
-    size_t n_draft_min = state.config.size_ngram; // size of n-gram to lookup in token history
-    size_t n_draft_max = state.config.size_mgram; // the m-gram following the found n-gram is used for draft
+    const size_t n_draft_min = config.size_ngram; // size of n-gram to lookup in token history
+    const size_t n_draft_max = config.size_mgram; // the m-gram following the found n-gram is used for draft
     // vector for tokens we want to verify.
     // return empty vector if there is no match.
@@ -53,9 +74,6 @@ llama_tokens common_ngram_simple_draft(
     }
     pattern.push_back(sampled); // add the last token to the pattern
-    // We do a search in the token history.
-    state.idx_last_check = cur_len;
     size_t match_pos = 0; // we ignore position 0, position 0 == no match
     // search backwards, but skip the current match (we are currently there)
     for (size_t j = cur_len - n_draft_min - 1; j > 0; --j) {
@@ -100,7 +118,99 @@
 // maximum number of counted values of a ngram map value.
 #define COMMON_NGRAM_MAX_VALUE_COUNT 16380
-static std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length);
void common_ngram_map_begin(
common_ngram_map & map, const llama_tokens & tokens) {
size_t size_begin = tokens.size();
LOG_DBG("%s: begin, idx_last_draft=%zu, new begin=%zu, #keys=%zu\n", __func__,
map.idx_last_check, size_begin, map.keys.size());
size_t count_map_entries_upd = 0;
if (!map.key_map.empty() && size_begin < map.idx_last_check) {
if (map.show_key_map_stats) {
// Print statistics of hash map map_key.
size_t count_nonzero = 0;
uint32_t min_idx = UINT32_MAX;
uint32_t max_idx = 0;
for (size_t i = 0; i < map.key_map.size(); ++i) {
uint32_t key_idx = map.key_map[i];
if (key_idx != 0) {
++count_nonzero;
if (key_idx < min_idx) min_idx = key_idx;
if (key_idx > max_idx) max_idx = key_idx;
}
}
if (count_nonzero == 0) {
min_idx = 0;
}
LOG_INF("%s: key_map stats: entries=%zu, min_idx=%u, max_idx=%u, key_map_last_idx=%u\n",
__func__, count_nonzero, min_idx, max_idx, map.key_map_last_idx);
}
// Update the map from hash to key index (clear outdated entries).
for (size_t i = 0; i < map.key_map.size(); ++i) {
uint32_t key_idx = map.key_map[i];
if (key_idx >= map.size_last_begin) {
map.key_map[i] = 0;
count_map_entries_upd++;
}
}
map.key_map_last_idx = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
}
if (size_begin < map.idx_last_check && !map.keys.empty()) {
// The next token generation will start at index size_begin.
// The tokens between map.size_last_begin and size_begin are no longer valid.
//
// Refresh map: Remove all entries with index >= map.size_last_begin.
size_t count_keys = map.keys.size();
size_t count_keys_del = 0;
size_t count_values_del = 0;
for (int32_t i = map.keys.size() - 1; i >= 0; --i) {
common_ngram_map_key & key = map.keys[i];
if (key.key_idx >= map.size_last_begin) {
// Delete the key.
LOG_DBG("%s: delete key %d at index %zu (>= size_last_begin=%zu)\n", __func__, i, key.key_idx, map.size_last_begin);
map.keys.erase(map.keys.begin() + i);
count_keys_del++;
continue;
}
if (map.key_only) {
continue;
}
// Check the indices of the values.
for (int16_t j = COMMON_NGRAM_MAX_VALUES - 1; j >= 0; --j) {
common_ngram_map_value & value = key.values[j];
if (value.value_idx >= map.size_last_begin) {
// Delete the value.
count_values_del++;
// Move all values after this value to the left.
for (uint16_t k = j; k < COMMON_NGRAM_MAX_VALUES - 1; ++k) {
key.values[k] = key.values[k + 1];
}
// Clear the last value.
key.values[COMMON_NGRAM_MAX_VALUES - 1].value_idx = 0;
key.values[COMMON_NGRAM_MAX_VALUES - 1].value_num = 0;
}
}
if (key.values[0].value_idx == 0) {
// No values left, delete the key.
LOG_DBG("%s: delete key %d at index %zu (no values left)\n", __func__, i, key.key_idx);
map.keys.erase(map.keys.begin() + i);
count_keys_del++;
}
}
LOG_INF("%s: refresh map: idx_last_draft=%zu, new begin=%zu, #keys_checked=%zu, #keys_del=%zu, #values_del=%zu, #hashes_upd=%zu\n", __func__,
map.idx_last_check, size_begin,
count_keys, count_keys_del, count_values_del, count_map_entries_upd);
}
map.idx_last_check = (map.size_last_begin > 0) ? map.size_last_begin - 1 : 0;
map.size_last_begin = size_begin;
}
void common_ngram_map_draft(common_ngram_map & map,
        const llama_tokens & inp, llama_token sampled,
@@ -116,6 +226,10 @@ void common_ngram_map_draft(common_ngram_map & map,
     if (cur_len < static_cast<size_t>(2 * n + m)) {
         return;
     }
+    if (cur_len >= static_cast<size_t>(UINT32_MAX)) {
+        // key_map uses uint32_t instead of size_t.
+        GGML_ABORT("%s: cur_len exceeds UINT32_MAX: %zu", __func__, cur_len);
+    }
     // Only check every check_rate tokens to save compute
     // i.e., perform check if (cur_len - idx_last_check) >= check_rate
@@ -134,24 +248,92 @@
    // search for the key in the map
    size_t match_pos = 0;
-    for (size_t j = cur_len - n - m - 1; j > 0; --j) {
-        bool match = true;
-        for (size_t k = 0; k < n; ++k) {
-            if (inp[j + k] != key_tokens[k]) {
-                match = false;
-                break;
-            }
-        }
-        if (match) {
-            match_pos = j;
-            break;
-        }
-    }
    if (map.size_last_begin > cur_len) {
        GGML_ABORT("%s: map.size_last_begin > cur_len: %zu > %zu", __func__, map.size_last_begin, cur_len);
    }
    if (!map.key_map.empty()) {
        // Search for the key in the map key_map from hash of ngrams to index of ngram.
        uint32_t idx_hash = (common_ngram_map_hash(key_tokens, 0, n) % map.key_map.size());
uint32_t idx_key = map.key_map[idx_hash];
if (idx_key != 0 && idx_key < cur_len - n - m - 1) {
// Check if the key matches the key at idx_key (because of possible collisions).
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[idx_key + k] != key_tokens[k]) {
match = false;
break;
}
}
LOG_DBG("%s: key hash %x -> idx_key %d: match %d\n", __func__, idx_hash, idx_key, match ? 1 : 0);
if (match) {
match_pos = idx_key;
            }
        }
    }
    if (match_pos == 0 && map.size_last_begin > (size_t) (n + m + 1)) {
        // Search for the key in [1, map.size_last_begin - n - m -1], descending.
for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) {
// Check if the key matches the key.
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[j + k] != key_tokens[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
}
}
if (match_pos == 0) {
// In case of a reasoning chat, the part after size_last_begin may be deleted/reordered later.
//
// Search in [size_last_begin, cur_len - n - m - 1], descending.
for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) {
bool match = true;
for (size_t k = 0; k < n; ++k) {
if (inp[j + k] != key_tokens[k]) {
match = false;
break;
}
}
if (match) {
match_pos = j;
break;
}
} }
} }
     if (match_pos > 0) {
-        LOG_INF("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
+        LOG_DBG("%s: cur_len = %zu, n = %d, m = %d, sz_tkns = %zu, sampled = %d, match_pos = %zu\n", __func__,
             cur_len, n, m, key_tokens.size(), sampled, match_pos);
     }
if (!map.key_map.empty()) {
// Add hashes of new ngrams in key_map.
//
// Use the same order as above.
if (map.size_last_begin > (size_t) (n + m + 1)) {
for (size_t j = map.size_last_begin - n - m - 1; j > map.key_map_last_idx; --j) {
// compute hash and store index of ngram at idx j in the map.
uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size());
if (map.key_map[idx_hash] == 0) {
map.key_map[idx_hash] = j; // collisions may occur
}
}
}
for (size_t j = cur_len - n - m - 1; j > map.size_last_begin && j > map.key_map_last_idx; --j) {
// compute hash and store index of ngram at idx j in the map.
uint32_t idx_hash = (common_ngram_map_hash(inp, j, n) % map.key_map.size());
if (map.key_map[idx_hash] == 0) {
map.key_map[idx_hash] = j;
}
}
map.key_map_last_idx = std::max(static_cast<uint32_t>(cur_len - n - m - 1), map.key_map_last_idx);
}
    if (match_pos == 0) {
        return;
    }
@@ -202,8 +384,8 @@ void common_ngram_map_draft(common_ngram_map & map,
         draft.push_back(inp[match_pos + n + i]);
     }
-    LOG_INF("%s: key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
-        key_offset, curr_key.key_num, draft.size());
+    LOG_DBG("%s: key_idx = %zu, key_offset = %zu, key_num = %d, draft.size = %zu\n", __func__,
+        curr_key.key_idx, key_offset, curr_key.key_num, draft.size());
     map.last_draft_created = false;
     map.last_draft_key_idx = key_offset;
@@ -305,7 +487,7 @@
         }
     }
-    if (sum_occur > 0 && max_occur < 3 * sum_occur) {
+    if (sum_occur > 0 && max_occur < 2 * sum_occur) {
         // The most frequent value is not much more frequent than the other values.
         // We do not use the draft.
         return;
@@ -347,21 +529,3 @@ void common_ngram_map_accept(common_ngram_map & map, uint16_t n_accepted) {
         n_accepted, curr_value.n_accepted);
     curr_value.n_accepted = n_accepted;
 }
-// Helper functions.
-//
-// Print the values of a sublist of `llama_tokens & inp` to a string in the form [v0, v1, v2, ...].
-std::string common_tokens_to_str(const llama_tokens & inp, size_t start, size_t length) {
-    std::ostringstream oss;
-    oss << '[';
-    for (size_t i = 0; i < length; ++i) {
-        if (i > 0) {
-            oss << ", ";
-        }
-        oss << inp[start + i];
-    }
-    oss << ']';
-    return oss.str();
-}
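As a note on the key_map fast path introduced above: buckets are selected by the 32-bit LCG hash of an n-gram reduced modulo the table size, and the first n-gram hashed into a bucket keeps it. The short standalone sketch below is not part of the patch; it only mirrors common_ngram_map_hash, and the table size 262144 mirrors COMMON_NGRAM_HASH_MAP_SIZE from ngram-map.h.

// Standalone sketch of the hash-then-index scheme used by key_map above.
#include <cstdint>
#include <cstdio>
#include <vector>

// LCG hash of an n-gram, mirroring common_ngram_map_hash.
static uint32_t ngram_hash(const std::vector<int32_t> & tokens, size_t start, size_t len) {
    uint32_t hash = 0;
    for (size_t i = 0; i < len; ++i) {
        hash = hash * 2654435761UL + tokens[start + i];
    }
    return hash;
}

int main() {
    const std::vector<int32_t> history = { 1, 10, 20, 30, 99, 10, 20, 30, 40 };
    const size_t n = 3;                       // n-gram size
    std::vector<uint32_t> key_map(262144, 0); // hash -> index of first occurrence, 0 = empty

    // Index the n-grams of the history; as in the patch, position 0 is ignored and the
    // first writer of a bucket wins (later collisions are simply skipped).
    for (size_t j = 1; j + n < history.size(); ++j) {
        const uint32_t bucket = ngram_hash(history, j, n) % key_map.size();
        if (key_map[bucket] == 0) {
            key_map[bucket] = (uint32_t) j;
        }
    }

    // The n-gram {10, 20, 30} at position 5 hashes to the bucket that already stores
    // its first occurrence at position 1 (collisions are possible but unlikely here).
    const uint32_t bucket = ngram_hash(history, 5, n) % key_map.size();
    printf("bucket %u -> n-gram index %u\n", bucket, key_map[bucket]);
    return 0;
}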

View File

@@ -9,8 +9,11 @@
 // 2. ngram_map: lookup of n-grams followed by m-grams in token history using a map.
 //    The map is a vector of key n-grams, and for each key n-gram there is a list of value m-grams.
 //
+// ref: https://github.com/ggml-org/llama.cpp/pull/18471
+//
 #include "llama.h"
+#include "common.h"

 #include <vector>
@@ -24,23 +27,9 @@ struct common_ngram_simple_config {
     uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
 };

-// current state (and config) of n-gram simple.
-struct common_ngram_simple_state {
-    common_ngram_simple_config config;
-    size_t idx_last_check = 0; // index of last check in context history (mutable)
-    common_ngram_simple_state(const common_ngram_simple_config & config)
-        : config(config) {}
-};
-
 // Searches for a n-gram in the history and checks whether a draft sequence should be generated.
-// state: the ngram simple state to search in.
-// inp: the tokens generated so far.
-// sampled: the token that was just sampled.
-// draft: vector to store the draft tokens, initially empty.
 llama_tokens common_ngram_simple_draft(
-    common_ngram_simple_state & state,
+    const common_ngram_simple_config & config,
     const llama_tokens & tokens, llama_token sampled);
@@ -50,10 +39,13 @@ llama_tokens common_ngram_simple_draft(
 // maximum number of m-gram values stored for each key n-gram.
 #define COMMON_NGRAM_MAX_VALUES 4
+// number of entries in the (optional, size 0 to disable) map from ngram-hash to ngram-index.
+#define COMMON_NGRAM_HASH_MAP_SIZE 262144

 // statistics of a m-gram after a known n-gram
 struct common_ngram_map_value {
     size_t   value_idx  = 0;  // index of value m-gram in token-history (0 if unused)
     uint16_t value_num  = 0;  // number of occurences of this value m-gram after the key n-gram (0 in an unused values-slot)
     int16_t  n_accepted = -1; // number of accepted tokens at last draft (-1 if unused)
 };
@@ -73,23 +65,43 @@ struct common_ngram_map {
     bool key_only; // true if only key n-grams are used, no values.
-    // first draft: vector only, no map.
     std::vector<common_ngram_map_key> keys; // key n-grams which occur several times in token-history
     uint16_t check_rate; // check for speculative decoding without draft model for each check_rate token
     uint16_t min_hits;   // minimum number of key hits to consider a draft
+    bool show_key_map_stats = false; // true, if statitics of the key_map should be printed.
     common_ngram_map(uint16_t sz_key, uint16_t sz_value, bool only_keys,
             uint16_t check_rate, uint16_t min_hits)
         : size_key(sz_key), size_value(sz_value), key_only(only_keys),
-          check_rate(check_rate), min_hits(min_hits) {}
+          check_rate(check_rate), min_hits(min_hits) {
+        key_map.resize(COMMON_NGRAM_HASH_MAP_SIZE); // 2^18 hash entries, 0 entries if key_map shouldn't be used
+    }
+
+    // In reasoning chats the previous reasoning block will be removed from context history.
+    // A rebuild of the ngram map is needed after that.
+    size_t size_last_begin = 0; // number of tokens at previous start of generation

     bool last_draft_created = false;   // true if a draft was created at last call.
-    size_t last_draft_key_idx = 0;     // index of last key used for draft generation.
+    size_t last_draft_key_idx = 0;     // index of last key used for draft generation (0 = no draft)
     uint16_t last_draft_value_idx = 0; // index of last value used for draft generation.
     size_t idx_last_check = 0; // index of last check in context history
+
+    // optional map "hash to ngram-index" for faster lookup of n-grams. map is empty if unused.
+    //
+    // uint32_t instead of size_t (size of current histories is << UINT32_MAX)
+    std::vector<uint32_t> key_map; // key_map[hash] = index of ngram in context window
+    uint32_t key_map_last_idx = 0; // index of the last ngram added to key_map
 };

+// Initialize the n-gram map with the given token history.
+// map: the ngram map to initialize.
+// tokens: the token history to base the map on.
+void common_ngram_map_begin(
+    common_ngram_map & map,
+    const llama_tokens & tokens);
+
 // Searches for the n-gram in the history and checks whether a draft sequence should be generated.
 // map: the ngram map to search in.

common/ngram-mod.cpp (new file, 60 lines)
View File

@@ -0,0 +1,60 @@
#include "ngram-mod.h"
//
// common_ngram_mod
//
common_ngram_mod::common_ngram_mod(uint16_t n, size_t size) : n(n), used(0) {
entries.resize(size);
reset();
}
size_t common_ngram_mod::idx(const entry_t * tokens) const {
size_t res = 0;
for (size_t i = 0; i < n; ++i) {
res = res*6364136223846793005ULL + tokens[i];
}
res = res % entries.size();
return res;
}
void common_ngram_mod::add(const entry_t * tokens) {
const size_t i = idx(tokens);
if (entries[i] == EMPTY) {
used++;
}
entries[i] = tokens[n];
}
common_ngram_mod::entry_t common_ngram_mod::get(const entry_t * tokens) const {
const size_t i = idx(tokens);
return entries[i];
}
void common_ngram_mod::reset() {
std::fill(entries.begin(), entries.end(), EMPTY);
used = 0;
}
size_t common_ngram_mod::get_n() const {
return n;
}
size_t common_ngram_mod::get_used() const {
return used;
}
size_t common_ngram_mod::size() const {
return entries.size();
}
size_t common_ngram_mod::size_bytes() const {
return entries.size() * sizeof(entries[0]);
}

common/ngram-mod.h (new file, 38 lines)
View File

@@ -0,0 +1,38 @@
#pragma once
#include <cstdint>
#include <vector>
#include <cstddef>
//
// common_ngram_mod
// ref: https://github.com/ggml-org/llama.cpp/pull/19164
//
// basic n-gram hasher
struct common_ngram_mod {
using entry_t = int32_t;
static constexpr entry_t EMPTY = -1;
common_ngram_mod(uint16_t n, size_t size);
size_t idx(const entry_t * tokens) const;
void add(const entry_t * tokens);
entry_t get(const entry_t * tokens) const; // return -1 if not found
void reset();
size_t get_n() const;
size_t get_used() const;
size_t size() const;
size_t size_bytes() const;
private:
size_t n; // ngram size to hash
size_t used;
std::vector<entry_t> entries;
};
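A minimal usage sketch for the new common_ngram_mod container follows. It is not part of the diff; it only assumes the header is available as common/ngram-mod.h as added above. Each add() hashes n consecutive tokens and records the token that follows them; get() returns the recorded follow-up token, or common_ngram_mod::EMPTY (-1) if the bucket was never written; hash collisions simply overwrite the previous entry.

// Usage sketch for common_ngram_mod (illustrative, not part of the PR).
#include "ngram-mod.h"

#include <cstdio>
#include <vector>

int main() {
    const uint16_t n = 2;                        // hash 2-grams
    common_ngram_mod mod(n, /*size =*/ 1 << 16); // 65536 hash slots

    const std::vector<int32_t> history = { 5, 7, 9, 5, 7 };

    // every add() consumes n key tokens plus the follow-up token at offset n
    for (size_t i = 0; i + n < history.size(); ++i) {
        mod.add(history.data() + i);
    }

    // the 2-gram {5, 7} was followed by 9 earlier in the history
    const int32_t key[] = { 5, 7 };
    printf("predicted next token: %d (occupancy %zu/%zu)\n",
           mod.get(key), mod.get_used(), mod.size());
    return 0;
}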

View File

@@ -6,6 +6,7 @@
 #include "log.h"
 #include "ngram-cache.h"
 #include "ngram-map.h"
+#include "ngram-mod.h"
 #include "sampling.h"

 #include <algorithm>
@@ -23,6 +24,7 @@ const std::vector<enum common_speculative_type> common_speculative_types = {
     COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K,
     COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V,
+    COMMON_SPECULATIVE_TYPE_NGRAM_MOD,
     COMMON_SPECULATIVE_TYPE_NGRAM_CACHE
 };
@@ -33,6 +35,7 @@ const std::map<std::string, enum common_speculative_type> common_speculative_typ
     {"ngram_simple", COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE},
     {"ngram_map_k", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K},
     {"ngram_map_k4v", COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V},
+    {"ngram_mod", COMMON_SPECULATIVE_TYPE_NGRAM_MOD},
     {"ngram_cache", COMMON_SPECULATIVE_TYPE_NGRAM_CACHE}
 };
@@ -110,6 +113,8 @@
 struct common_speculative_state {
     const enum common_speculative_type type;
+    // TODO: rename to n_call_draft, n_gen_drafts, n_acc_drafts, n_gen_tokens, n_acc_tokens
+    // TODO: add n_call_begin, n_call_accept
     size_t drafts_call_count = 0;      // number of times this implementation was called.
     size_t drafts_generated_count = 0; // number of times a draft or part was generated by this implementation.
     size_t drafts_accepted_count = 0;  // number of times a draft or part was accepted by the target model.
@@ -119,7 +124,9 @@
     // TODO: track performance of most recent calls
     const bool gen_perf = true; // whether to generate performance stats.
-    int64_t gen_duration_us = 0; // total time spent in this implementation in microseconds.
+    int64_t t_begin_us  = 0; // total time spent in refresh of this implementation in microseconds.
+    int64_t t_draft_us  = 0; // total time spent in generating drafts in this implementation in microseconds.
+    int64_t t_accept_us = 0; // total time spent in accumulation of this implementation in microseconds.
     common_speculative_state(enum common_speculative_type type) : type(type) {}
@ -456,12 +463,14 @@ struct common_speculative_state_eagle3 : public common_speculative_state {
// state of self-speculation (simple implementation, not ngram-map) // state of self-speculation (simple implementation, not ngram-map)
struct common_speculative_state_ngram_simple : public common_speculative_state { struct common_speculative_state_ngram_simple : public common_speculative_state {
common_ngram_simple_state state; common_ngram_simple_config config;
uint16_t check_id = 0; // used to control the frequency of generating drafts
common_speculative_state_ngram_simple( common_speculative_state_ngram_simple(
enum common_speculative_type type, enum common_speculative_type type,
common_ngram_simple_state state) common_ngram_simple_config config)
: common_speculative_state(type), state(state) {} : common_speculative_state(type), config(config) {}
void begin(const llama_tokens & prompt) override { void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt); GGML_UNUSED(prompt);
@ -472,7 +481,13 @@ struct common_speculative_state_ngram_simple : public common_speculative_state {
const llama_tokens & prompt_tgt, const llama_tokens & prompt_tgt,
llama_token id_last, llama_token id_last,
llama_tokens & result) override { llama_tokens & result) override {
result = common_ngram_simple_draft(state, prompt_tgt, id_last); ++check_id;
if (check_id < config.check_rate) {
return;
}
check_id = 0;
result = common_ngram_simple_draft(config, prompt_tgt, id_last);
GGML_UNUSED(params); GGML_UNUSED(params);
} }
@ -492,7 +507,7 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
: common_speculative_state(type), map(std::move(map)) {} : common_speculative_state(type), map(std::move(map)) {}
void begin(const llama_tokens & prompt) override { void begin(const llama_tokens & prompt) override {
GGML_UNUSED(prompt); common_ngram_map_begin(map, prompt);
} }
void draft( void draft(
@ -509,6 +524,132 @@ struct common_speculative_state_ngram_map_k : public common_speculative_state {
} }
}; };
struct common_speculative_state_ngram_mod : public common_speculative_state {
common_ngram_mod & mod;
// the last position in the prompt that was added to the ngram container
size_t i_last = 0;
// length of the last drafted ngram (number of tokens returned by draft)
size_t n_draft_last = 0;
// consecutive accept rounds with low acceptance fraction (< 0.5)
int n_low = 0;
// enable trace logging if LLAMA_TRACE is set
const bool verbose;
common_speculative_state_ngram_mod(enum common_speculative_type type, common_ngram_mod & mod)
: common_speculative_state(type), mod(mod), verbose(std::getenv("LLAMA_TRACE") != nullptr) {
static_assert(sizeof(llama_token) == sizeof(common_ngram_mod::entry_t));
}
void begin(const llama_tokens & prompt) override {
i_last = 0;
n_draft_last = 0;
const size_t n = mod.get_n();
if (prompt.size() < n) {
return;
}
for (size_t i = 0; i < prompt.size() - n; ++i) {
mod.add(prompt.data() + i);
}
i_last = prompt.size() - n;
const double f = (double)mod.get_used() / (double)mod.size();
LOG_INF("%s: ngram_mod occupancy = %zu/%zu (%.2f)\n", __func__, mod.get_used(), mod.size(), f);
constexpr double f_thold = 0.25;
if (f > f_thold) {
LOG_WRN("%s: ngram_mod occupancy %.2f exceeds threshold (%.2f) - resetting\n", __func__, f, f_thold);
mod.reset();
}
}
void draft(
const common_params_speculative & params,
const llama_tokens & prompt_tgt,
llama_token id_last,
llama_tokens & result) override {
GGML_UNUSED(params);
n_draft_last = 0;
const size_t cur_len = prompt_tgt.size();
if (cur_len < mod.get_n()) {
return;
}
const size_t n = mod.get_n();
// add new ngrams in chunks
if (i_last + 32 < cur_len) {
for (size_t i = i_last; i < cur_len - n; ++i) {
mod.add(prompt_tgt.data() + i);
}
i_last = cur_len - n;
}
result.resize(n + params.n_max);
for (size_t i = 0; i < n - 1; ++i) {
result[i] = prompt_tgt[cur_len - n + 1 + i];
}
result[n - 1] = id_last;
for (int i = 0; i < params.n_max; ++i) {
const llama_token token = mod.get(result.data() + i);
if (token == common_ngram_mod::EMPTY) {
if (i < params.n_min) {
result.clear();
return;
}
result.resize(n + i);
break;
}
result[n + i] = token;
}
// only return the m tokens that were drafted
for (size_t i = 0; n + i < result.size(); ++i) {
result[i] = result[n + i];
}
result.resize(result.size() - n);
// store length of drafted ngram for later acceptance analysis
n_draft_last = result.size();
}
void accept(uint16_t n_accepted) override {
if (verbose) {
LOG_INF("%s: accepted %d tokens from %zu drafted tokens\n", __func__, n_accepted, n_draft_last);
}
// compute acceptance fraction if we have a recorded draft length
if (n_draft_last > 0) {
const double f_acc = (double)n_accepted / (double)n_draft_last;
if (f_acc < 0.5) {
n_low++;
if (n_low >= 3) {
LOG_WRN("%s: low acceptance streak (%d) resetting ngram_mod\n", __func__, n_low);
mod.reset();
n_low = 0;
}
} else {
n_low = 0;
}
}
}
};
struct common_speculative_state_ngram_cache : public common_speculative_state { struct common_speculative_state_ngram_cache : public common_speculative_state {
uint16_t n_draft; uint16_t n_draft;
bool save_dynamic; bool save_dynamic;
@ -650,6 +791,7 @@ std::string common_speculative_type_to_str(enum common_speculative_type type) {
case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple"; case COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE: return "ngram_simple";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K: return "ngram_map_k";
case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v"; case COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V: return "ngram_map_k4v";
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: return "ngram_mod";
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache"; case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: return "ngram_cache";
default: return "unknown"; default: return "unknown";
} }
@ -666,8 +808,8 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
// initialization of the speculative decoding system // initialization of the speculative decoding system
// //
common_speculative * common_speculative_init( common_speculative * common_speculative_init(
const common_params_speculative & params, common_params_speculative & params,
llama_context * ctx_tgt) { llama_context * ctx_tgt) {
llama_context * ctx_dft = nullptr; llama_context * ctx_dft = nullptr;
if (params.model_dft) { if (params.model_dft) {
ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft); ctx_dft = llama_init_from_model(params.model_dft, params.cparams_dft);
@ -687,6 +829,7 @@ common_speculative * common_speculative_init(
bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE); bool has_ngram_simple = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_SIMPLE);
bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K); bool has_ngram_map_k = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K);
bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V); bool has_ngram_map_k4v = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V);
bool has_ngram_mod = (params.type == COMMON_SPECULATIVE_TYPE_NGRAM_MOD);
// In a more complex implementation we could use the same implementation but with different parameters. // In a more complex implementation we could use the same implementation but with different parameters.
// This was initially used in PR-18471 but removed to simplify the code. // This was initially used in PR-18471 but removed to simplify the code.
@ -701,6 +844,22 @@ common_speculative * common_speculative_init(
// This implementation can guess tokens with high acceptance rate but is more expensive. // This implementation can guess tokens with high acceptance rate but is more expensive.
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params)); configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MAP_K4V, params));
} }
if (has_ngram_mod) {
// shared instance for all speculative decoding contexts
if (!params.ngram_mod) {
params.ngram_mod = std::make_shared<common_ngram_mod>(params.ngram_size_n, 4*1024*1024);
LOG_INF("%s: initialized ngram_mod with n=%d, size=%zu (%.3f MB)\n", __func__,
params.ngram_size_n, params.ngram_mod->size(),
(float)(params.ngram_mod->size_bytes())/1024/1024);
if (params.ngram_size_n < 16) {
LOG_WRN("%s: ngram_mod n=%d is too small - poor quality is possible, see: https://github.com/ggml-org/llama.cpp/pull/19164\n", __func__, params.ngram_size_n);
}
}
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_MOD, params));
}
if (has_ngram_cache) { if (has_ngram_cache) {
configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params)); configs.push_back(common_speculative_config(COMMON_SPECULATIVE_TYPE_NGRAM_CACHE, params));
} }
@ -738,14 +897,14 @@ common_speculative * common_speculative_init(
uint16_t mgram_size_value = ngram_map.size_value; uint16_t mgram_size_value = ngram_map.size_value;
uint16_t check_rate = ngram_map.check_rate; uint16_t check_rate = ngram_map.check_rate;
auto config_simple = common_ngram_simple_config{ auto config_simple = common_ngram_simple_config {
/* .size_ngram = */ ngram_size_key, /* .size_ngram = */ ngram_size_key,
/* .size_mgram = */ mgram_size_value, /* .size_mgram = */ mgram_size_value,
/* .check_rate = */ check_rate /* .check_rate = */ check_rate
}; };
auto state = std::make_unique<common_speculative_state_ngram_simple>( auto state = std::make_unique<common_speculative_state_ngram_simple>(
/* .type = */ config.type, /* .type = */ config.type,
/* .state = */ common_ngram_simple_state(config_simple) /* .state = */ config_simple
); );
impls.push_back(std::move(state)); impls.push_back(std::move(state));
break; break;
@ -758,6 +917,11 @@ common_speculative * common_speculative_init(
)); ));
break; break;
} }
case COMMON_SPECULATIVE_TYPE_NGRAM_MOD: {
GGML_ASSERT(config.params.ngram_mod);
impls.push_back(std::make_unique<common_speculative_state_ngram_mod>(config.type, *config.params.ngram_mod));
break;
}
case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: { case COMMON_SPECULATIVE_TYPE_NGRAM_CACHE: {
auto state = create_state_ngram_cache( auto state = create_state_ngram_cache(
params.lookup_cache_static, params.lookup_cache_dynamic, config); params.lookup_cache_static, params.lookup_cache_dynamic, config);
@ -795,6 +959,7 @@ void common_speculative_begin(common_speculative * spec, const llama_tokens & pr
} }
for (auto & impl : spec->impls) { for (auto & impl : spec->impls) {
common_time_meas tm(impl->t_begin_us, !impl->gen_perf);
impl->begin(prompt); impl->begin(prompt);
} }
} }
@ -810,20 +975,14 @@ llama_tokens common_speculative_draft(
for (auto & impl : spec->impls) { for (auto & impl : spec->impls) {
{ {
const int64_t t_start_us = impl->gen_perf ? ggml_time_us() : 0; common_time_meas tm(impl->t_draft_us, !impl->gen_perf);
impl->draft(params, prompt_tgt, id_last, result); impl->draft(params, prompt_tgt, id_last, result);
const int64_t t_now_us = impl->gen_perf ? ggml_time_us() : 0;
impl->drafts_call_count++; impl->drafts_call_count++;
impl->gen_duration_us += t_now_us - t_start_us; // accumulate duration for this implementation
} }
if (!result.empty()) { if (!result.empty()) {
LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__, LOG_DBG("%s: called impl %s, hist size = %zu, call_count = %zu, gen = %zu\n", __func__,
common_speculative_type_to_str(impl.get()->type).c_str(), common_speculative_type_to_str(impl.get()->type).c_str(), prompt_tgt.size(),
prompt_tgt.size(),
impl.get()->drafts_call_count, result.size()); impl.get()->drafts_call_count, result.size());
spec->curr_impl = impl.get(); // set current implementation for stats spec->curr_impl = impl.get(); // set current implementation for stats
@ -846,12 +1005,15 @@ void common_speculative_accept(common_speculative * spec, uint16_t n_accepted) {
GGML_ASSERT(impl); GGML_ASSERT(impl);
if (n_accepted > 0) { {
impl->drafts_accepted_count++; common_time_meas tm(impl->t_accept_us, !impl->gen_perf);
impl->drafts_accepted_tokens += n_accepted; if (n_accepted > 0) {
} impl->drafts_accepted_count++;
impl->drafts_accepted_tokens += n_accepted;
}
impl->accept(n_accepted); impl->accept(n_accepted);
}
} }
void common_speculative_print_stats(const common_speculative * spec) { void common_speculative_print_stats(const common_speculative * spec) {
@ -863,8 +1025,10 @@ void common_speculative_print_stats(const common_speculative * spec) {
std::string str_perf; std::string str_perf;
if (impl->gen_perf) { if (impl->gen_perf) {
std::ostringstream oss; std::ostringstream oss;
oss << std::fixed << std::setprecision(3) << impl->gen_duration_us / 1000.0; oss << std::fixed << std::setprecision(3) << impl->t_begin_us / 1000.0 << ", ";
str_perf = ", dur = " + oss.str() + " ms"; oss << std::fixed << std::setprecision(3) << impl->t_draft_us / 1000.0 << ", ";
oss << std::fixed << std::setprecision(3) << impl->t_accept_us / 1000.0;
str_perf = ", dur(b,g,a) = " + oss.str() + " ms";
} else { } else {
str_perf = ""; str_perf = "";
} }

View File

@ -15,8 +15,8 @@ enum common_speculative_type common_speculative_type_from_name(const std::string
std::string common_speculative_type_to_str(enum common_speculative_type type); std::string common_speculative_type_to_str(enum common_speculative_type type);
common_speculative * common_speculative_init( common_speculative * common_speculative_init(
const common_params_speculative & params, common_params_speculative & params,
llama_context * ctx_tgt); llama_context * ctx_tgt);
void common_speculative_free(common_speculative * spec); void common_speculative_free(common_speculative * spec);

View File

@ -8806,6 +8806,7 @@ class GraniteMoeModel(GraniteModel):
gate, up = data_torch.split(ffn_dim, dim=-2) gate, up = data_torch.split(ffn_dim, dim=-2)
yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid) yield from ModelBase.modify_tensors(self, gate, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_GATE_EXP, bid), bid)
yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid) yield from ModelBase.modify_tensors(self, up, self.format_tensor_name(gguf.MODEL_TENSOR.FFN_UP_EXP, bid), bid)
return
has_experts = bool(self.hparams.get('num_local_experts')) has_experts = bool(self.hparams.get('num_local_experts'))

View File

@ -22,12 +22,11 @@
- **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers. - **DPCPP** *(Data Parallel C++)*: The primary oneAPI SYCL implementation, which includes the icpx/icx Compilers.
- **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*. - **oneAPI Libraries**: A set of highly optimized libraries targeting multiple domains *(e.g. Intel oneMKL, oneMath and oneDNN)*.
- **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs. - **oneAPI LevelZero**: A high performance low level interface for fine-grained control over Intel iGPUs and dGPUs.
- **Nvidia & AMD Plugins**: These are plugins extending oneAPI's DPCPP support to SYCL on Nvidia and AMD GPU targets.
### Llama.cpp + SYCL ### Llama.cpp + SYCL
The llama.cpp SYCL backend is primarily designed for **Intel GPUs**. The llama.cpp SYCL backend is primarily designed for **Intel GPUs**.
SYCL cross-platform capabilities enable support for Nvidia GPUs as well, with limited support for AMD. SYCL cross-platform capabilities enable support for other vendor GPUs as well.
## Recommended Release ## Recommended Release
@ -35,13 +34,16 @@ The following releases are verified and recommended:
|Commit ID|Tag|Release|Verified Platform| Update date| |Commit ID|Tag|Release|Verified Platform| Update date|
|-|-|-|-|-| |-|-|-|-|-|
|24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |ArcB580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15| |24e86cae7219b0f3ede1d5abdf5bf3ad515cccb8|b5377 |[llama-b5377-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b5377/llama-b5377-bin-win-sycl-x64.zip) |Arc B580/Linux/oneAPI 2025.1<br>LNL Arc GPU/Windows 11/oneAPI 2025.1.1|2025-05-15|
|3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19| |3bcd40b3c593d14261fb2abfabad3c0fb5b9e318|b4040 |[llama-b4040-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b4040/llama-b4040-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1| 2024-11-19|
|fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1|| |fb76ec31a9914b7761c1727303ab30380fd4f05c|b3038 |[llama-b3038-bin-win-sycl-x64.zip](https://github.com/ggml-org/llama.cpp/releases/download/b3038/llama-b3038-bin-win-sycl-x64.zip) |Arc A770/Linux/oneAPI 2024.1<br>MTL Arc GPU/Windows 11/oneAPI 2024.1||
## News ## News
- 2026.02
- Removed support for Nvidia & AMD GPUs: the oneAPI plugins for these targets are no longer available (their download/installation channels are out of service), so the software can no longer be built for Nvidia & AMD GPUs.
- 2025.11 - 2025.11
- Support malloc memory on device more than 4GB. - Support malloc memory on device more than 4GB.
@ -51,7 +53,7 @@ The following releases are verified and recommended:
|-|-|-|-| |-|-|-|-|
|PVC 1550|39|73|+87%| |PVC 1550|39|73|+87%|
|Flex 170|39|50|+28%| |Flex 170|39|50|+28%|
|Arc770|42|55|+30%| |Arc A770|42|55|+30%|
|MTL|13|16|+23%| |MTL|13|16|+23%|
|ARL-H|14|17|+21%| |ARL-H|14|17|+21%|
@ -62,7 +64,7 @@ The following releases are verified and recommended:
- Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs. - Use oneDNN as the default GEMM library, improve the compatibility for new Intel GPUs.
- 2024.5 - 2024.5
- Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc770. - Performance is increased: 34 -> 37 tokens/s of llama-2-7b.Q4_0 on Arc A770.
- Arch Linux is verified successfully. - Arch Linux is verified successfully.
- 2024.4 - 2024.4
@ -111,14 +113,15 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
|-------------------------------|---------|---------------------------------------| |-------------------------------|---------|---------------------------------------|
| Intel Data Center Max Series | Support | Max 1550, 1100 | | Intel Data Center Max Series | Support | Max 1550, 1100 |
| Intel Data Center Flex Series | Support | Flex 170 | | Intel Data Center Flex Series | Support | Flex 170 |
| Intel Arc Series | Support | Arc 770, 730M, Arc A750, B580 | | Intel Arc A-Series | Support | Arc A770, Arc A730M, Arc A750 |
| Intel Arc B-Series | Support | Arc B580 |
| Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake | | Intel built-in Arc GPU | Support | built-in Arc GPU in Meteor Lake, Arrow Lake, Lunar Lake |
| Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 | | Intel iGPU | Support | iGPU in 13700k, 13400, i5-1250P, i7-1260P, i7-1165G7 |
*Notes:* *Notes:*
- **Memory** - **Memory**
- The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-cli`. - The device memory is a limitation when running a large model. The loaded model size, *`llm_load_tensors: buffer_size`*, is displayed in the log when running `./bin/llama-completion`.
- Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU. - Please make sure the GPU shared memory from the host is large enough to account for the model's size. For e.g. the *llama-2-7b.Q4_0* requires at least 8.0GB for integrated GPU and 4.0GB for discrete GPU.
- **Execution Unit (EU)** - **Execution Unit (EU)**
@ -126,20 +129,7 @@ On older Intel GPUs, you may try [OpenCL](/docs/backend/OPENCL.md) although the
### Other Vendor GPU ### Other Vendor GPU
**Verified devices** NA
| Nvidia GPU | Status | Verified Model |
|--------------------------|-----------|----------------|
| Ampere Series | Supported | A100, A4000 |
| Ampere Series *(Mobile)* | Supported | RTX 40 Series |
| AMD GPU | Status | Verified Model |
|--------------------------|--------------|----------------|
| Radeon Pro | Experimental | W6800 |
| Radeon RX | Experimental | 6700 XT |
Note: AMD GPU support is highly experimental and is incompatible with F16.
Additionally, it only supports GPUs with a sub_group_size (warp size) of 32.
## Docker ## Docker
@ -148,11 +138,11 @@ The docker build option is currently limited to *Intel GPU* targets.
### Build image ### Build image
```sh ```sh
# Using FP16
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
# Using FP32 # Using FP32
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile . docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=OFF" --target light -f .devops/intel.Dockerfile .
# Using FP16
docker build -t llama-cpp-sycl --build-arg="GGML_SYCL_F16=ON" --target light -f .devops/intel.Dockerfile .
``` ```
*Notes*: *Notes*:
@ -211,14 +201,6 @@ Platform #0: Intel(R) OpenCL HD Graphics
`-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49] `-- Device #0: Intel(R) Iris(R) Xe Graphics [0x9a49]
``` ```
- **Nvidia GPU**
In order to target Nvidia GPUs through SYCL, please make sure the CUDA/CUBLAS native requirements *-found [here](README.md#cuda)-* are installed.
- **AMD GPU**
To target AMD GPUs with SYCL, the ROCm stack must be installed first.
2. **Install Intel® oneAPI Base toolkit** 2. **Install Intel® oneAPI Base toolkit**
SYCL backend depends on: SYCL backend depends on:
@ -247,23 +229,6 @@ Upon a successful installation, SYCL is enabled for the available intel devices,
|2025.1| |2025.1|
|2024.1| |2024.1|
- **Adding support to Nvidia GPUs**
**oneAPI Plugin**: In order to enable SYCL support on Nvidia GPUs, please install the [Codeplay oneAPI Plugin for Nvidia GPUs](https://developer.codeplay.com/products/oneapi/nvidia/download). User should also make sure the plugin version matches the installed base toolkit one *(previous step)* for a seamless "oneAPI on Nvidia GPU" setup.
**oneDNN**: The current oneDNN releases *(shipped with the oneAPI base-toolkit)* do not include the NVIDIA backend. Therefore, oneDNN must be compiled from source to enable the NVIDIA target:
```sh
git clone https://github.com/oneapi-src/oneDNN.git
cd oneDNN
cmake -GNinja -Bbuild-nvidia -DDNNL_CPU_RUNTIME=DPCPP -DDNNL_GPU_RUNTIME=DPCPP -DDNNL_GPU_VENDOR=NVIDIA -DONEDNN_BUILD_GRAPH=OFF -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
cmake --build build-nvidia --config Release
```
- **Adding support to AMD GPUs**
**oneAPI Plugin**: In order to enable SYCL support on AMD GPUs, please install the [Codeplay oneAPI Plugin for AMD GPUs](https://developer.codeplay.com/products/oneapi/amd/download). As with Nvidia GPUs, the user should also make sure the plugin version matches the installed base toolkit.
3. **Verify installation and environment** 3. **Verify installation and environment**
In order to check the available SYCL devices on the machine, please use the `sycl-ls` command. In order to check the available SYCL devices on the machine, please use the `sycl-ls` command.
@ -284,25 +249,6 @@ When targeting an intel GPU, the user should expect one or more devices among th
[opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO [24.39.31294] [opencl:gpu][opencl:2] Intel(R) OpenCL Graphics, Intel(R) UHD Graphics 730 OpenCL 3.0 NEO [24.39.31294]
``` ```
- **Nvidia GPU**
Similarly, user targeting Nvidia GPUs should expect at least one SYCL-CUDA device [`cuda:gpu`] as below:
```
[opencl:acc][opencl:0] Intel(R) FPGA Emulation Platform for OpenCL(TM), Intel(R) FPGA Emulation Device OpenCL 1.2 [2023.16.12.0.12_195853.xmain-hotfix]
[opencl:cpu][opencl:1] Intel(R) OpenCL, Intel(R) Xeon(R) Gold 6326 CPU @ 2.90GHz OpenCL 3.0 (Build 0) [2023.16.12.0.12_195853.xmain-hotfix]
[cuda:gpu][cuda:0] NVIDIA CUDA BACKEND, NVIDIA A100-PCIE-40GB 8.0 [CUDA 12.5]
```
- **AMD GPU**
For AMD GPUs we should expect at least one SYCL-HIP device [`hip:gpu`]:
```
[opencl:cpu][opencl:0] Intel(R) OpenCL, 12th Gen Intel(R) Core(TM) i9-12900K OpenCL 3.0 (Build 0) [2024.18.6.0.02_160000]
[hip:gpu][hip:0] AMD HIP BACKEND, AMD Radeon PRO W6800 gfx1030 [HIP 60140.9]
```
### II. Build llama.cpp ### II. Build llama.cpp
#### Intel GPU #### Intel GPU
@ -331,47 +277,6 @@ It is possible to come across some precision issues when running tests that stem
instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS` instructions, which can be circumvented by setting the environment variable `SYCL_PROGRAM_COMPILE_OPTIONS`
as `-cl-fp32-correctly-rounded-divide-sqrt` as `-cl-fp32-correctly-rounded-divide-sqrt`
#### Nvidia GPU
The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.
```sh
# Build LLAMA with Nvidia BLAS acceleration through SYCL
# Setting GGML_SYCL_DEVICE_ARCH is optional but can improve performance
GGML_SYCL_DEVICE_ARCH=sm_80 # Example architecture
# Option 1: Use FP32 (recommended for better performance in most cases)
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
# Option 2: Use FP16
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=NVIDIA -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON -DDNNL_DIR=/path/to/oneDNN/build-nvidia/install/lib/cmake/dnnl
# build all binary
cmake --build build --config Release -j -v
```
It is possible to come across some precision issues when running tests that stem from using faster
instructions, which can be circumvented by passing the `-fno-fast-math` flag to the compiler.
#### AMD GPU
The SYCL backend depends on [oneMath](https://github.com/uxlfoundation/oneMath) for Nvidia and AMD devices.
By default it is automatically built along with the project. A specific build can be provided by setting the CMake flag `-DoneMath_DIR=/path/to/oneMath/install/lib/cmake/oneMath`.
```sh
# Build LLAMA with rocBLAS acceleration through SYCL
## AMD
# Use FP32, FP16 is not supported
# Find your GGML_SYCL_DEVICE_ARCH with rocminfo, under the key 'Name:'
GGML_SYCL_DEVICE_ARCH=gfx90a # Example architecture
cmake -B build -DGGML_SYCL=ON -DGGML_SYCL_TARGET=AMD -DGGML_SYCL_DEVICE_ARCH=${GGML_SYCL_DEVICE_ARCH} -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx
# build all binary
cmake --build build --config Release -j -v
```
### III. Run the inference ### III. Run the inference
#### Retrieve and prepare model #### Retrieve and prepare model
@ -422,16 +327,12 @@ Choose one of following methods to run.
- Use device 0: - Use device 0:
```sh ```sh
./examples/sycl/run-llama2.sh 0 ./examples/sycl/test.sh -mg 0
# OR
./examples/sycl/run-llama3.sh 0
``` ```
- Use multiple devices: - Use multiple devices:
```sh ```sh
./examples/sycl/run-llama2.sh ./examples/sycl/test.sh
# OR
./examples/sycl/run-llama3.sh
``` ```
2. Command line 2. Command line
@ -454,13 +355,13 @@ Examples:
- Use device 0: - Use device 0:
```sh ```sh
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap
``` ```
- Use multiple devices: - Use multiple devices:
```sh ```sh
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:" -n 400 -e -ngl 99 -sm layer --mmap
``` ```
*Notes:* *Notes:*
@ -576,13 +477,13 @@ Or, use CMake presets to build:
```sh ```sh
cmake --preset x64-windows-sycl-release cmake --preset x64-windows-sycl-release
cmake --build build-x64-windows-sycl-release -j --target llama-cli cmake --build build-x64-windows-sycl-release -j --target llama-completion
cmake -DGGML_SYCL_F16=ON --preset x64-windows-sycl-release cmake -DGGML_SYCL_F16=ON --preset x64-windows-sycl-release
cmake --build build-x64-windows-sycl-release -j --target llama-cli cmake --build build-x64-windows-sycl-release -j --target llama-completion
cmake --preset x64-windows-sycl-debug cmake --preset x64-windows-sycl-debug
cmake --build build-x64-windows-sycl-debug -j --target llama-cli cmake --build build-x64-windows-sycl-debug -j --target llama-completion
``` ```
#### 3. Visual Studio #### 3. Visual Studio
@ -607,7 +508,7 @@ You can use Visual Studio to open the `llama.cpp` folder directly as a CMake pro
- For a minimal experimental setup, you can build only the inference executable using: - For a minimal experimental setup, you can build only the inference executable using:
```Powershell ```Powershell
cmake --build build --config Release -j --target llama-cli cmake --build build --config Release -j --target llama-completion
``` ```
##### - Generating a Visual Studio Solution ##### - Generating a Visual Studio Solution
@ -713,13 +614,7 @@ Choose one of following methods to run.
1. Script 1. Script
``` ```
examples\sycl\win-run-llama-2.bat examples\sycl\win-test.bat
```
or
```
examples\sycl\win-run-llama-3.bat
``` ```
2. Command line 2. Command line
@ -743,13 +638,13 @@ Examples:
- Use device 0: - Use device 0:
``` ```
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm none -mg 0 --mmap
``` ```
- Use multiple devices: - Use multiple devices:
``` ```
build\bin\llama-cli.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer build\bin\llama-completion.exe -no-cnv -m models\llama-2-7b.Q4_0.gguf -p "Building a website can be done in 10 simple steps:\nStep 1:" -n 400 -e -ngl 99 -sm layer --mmap
``` ```
@ -775,15 +670,15 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| Name | Value | Function | | Name | Value | Function |
|--------------------|---------------------------------------|---------------------------------------------| |--------------------|---------------------------------------|---------------------------------------------|
| GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. | | GGML_SYCL | ON (mandatory) | Enable build with SYCL code path. |
| GGML_SYCL_TARGET | INTEL *(default)* \| NVIDIA \| AMD | Set the SYCL target device type. | | GGML_SYCL_TARGET | INTEL *(default)* | Set the SYCL target device type. |
| GGML_SYCL_DEVICE_ARCH | Optional (except for AMD) | Set the SYCL device architecture, optional except for AMD. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. | | GGML_SYCL_DEVICE_ARCH | Optional | Set the SYCL device architecture. Setting the device architecture can improve the performance. See the table [--offload-arch](https://github.com/intel/llvm/blob/sycl/sycl/doc/design/OffloadDesign.md#--offload-arch) for a list of valid architectures. |
| GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) | | GGML_SYCL_F16 | OFF *(default)* \|ON *(optional)* | Enable FP16 build with SYCL code path. (1.) |
| GGML_SYCL_GRAPH | ON *(default)* \|OFF *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). | | GGML_SYCL_GRAPH | OFF *(default)* \|ON *(Optional)* | Enable build with [SYCL Graph extension](https://github.com/intel/llvm/blob/sycl/sycl/doc/extensions/experimental/sycl_ext_oneapi_graph.asciidoc). |
| GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. | | GGML_SYCL_DNN | ON *(default)* \|OFF *(Optional)* | Enable build with oneDNN. |
| CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. | | CMAKE_C_COMPILER | `icx` *(Linux)*, `icx/cl` *(Windows)* | Set `icx` compiler for SYCL code path. |
| CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. | | CMAKE_CXX_COMPILER | `icpx` *(Linux)*, `icx` *(Windows)* | Set `icpx/icx` compiler for SYCL code path. |
1. FP16 is recommended for better prompt processing performance on quantized models. Performance is equivalent in text generation but set `GGML_SYCL_F16=OFF` if you are experiencing issues with FP16 builds. 1. FP32 and FP16 can have different performance impacts depending on the model. It is recommended to test both to find the better prompt processing performance for your models. You need to rebuild after changing `GGML_SYCL_F16=OFF/ON`.
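
For illustration only (assuming the Linux icx/icpx toolchain already described in this document), toggling FP16 means reconfiguring and rebuilding:

```sh
# reconfigure with FP16 enabled, then rebuild
cmake -B build -DGGML_SYCL=ON -DCMAKE_C_COMPILER=icx -DCMAKE_CXX_COMPILER=icpx -DGGML_SYCL_F16=ON
cmake --build build --config Release -j
```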
#### Runtime #### Runtime
@ -791,7 +686,7 @@ use 1 SYCL GPUs: [0] with Max compute units:512
|-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------| |-------------------|------------------|---------------------------------------------------------------------------------------------------------------------------|
| GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG | | GGML_SYCL_DEBUG | 0 (default) or 1 | Enable log function by macro: GGML_SYCL_DEBUG |
| GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) | | GGML_SYCL_DISABLE_OPT | 0 (default) or 1 | Disable optimize features for Intel GPUs. (Recommended to 1 for intel devices older than Gen 10) |
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because SYCL Graph is still under development and does not yet improve performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.| | UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
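
As an illustrative combination (not prescribed by the table above), the runtime variables are simply set in the environment of the launched binary:

```sh
# example: query free GPU memory via SYSMAN and allow >4GB device allocations
ZES_ENABLE_SYSMAN=1 UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 \
  ./build/bin/llama-completion -no-cnv -m models/llama-2-7b.Q4_0.gguf -p "Hello" -n 400 -e -ngl 99 -sm layer
```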

View File

@ -1,5 +1,5 @@
{ {
"version": 4, "version": 5,
"configurePresets": [ "configurePresets": [
{ {
"name": "arm64-android-snapdragon", "name": "arm64-android-snapdragon",
@ -16,7 +16,9 @@
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG", "CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", "CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g", "CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", "CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
"PREBUILT_LIB_DIR": "android_aarch64", "PREBUILT_LIB_DIR": "android_aarch64",
"GGML_OPENMP": "OFF", "GGML_OPENMP": "OFF",
"GGML_LLAMAFILE": "OFF", "GGML_LLAMAFILE": "OFF",
@ -31,7 +33,15 @@
"name": "arm64-windows-snapdragon", "name": "arm64-windows-snapdragon",
"inherits": [ "base", "arm64-windows-llvm" ], "inherits": [ "base", "arm64-windows-llvm" ],
"cacheVariables": { "cacheVariables": {
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}", "CMAKE_C_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
"CMAKE_CXX_FLAGS": "-march=armv8.7a+fp16 -fvectorize -ffp-model=fast -flto -D_GNU_SOURCE",
"CMAKE_C_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_CXX_FLAGS_RELEASE": "-O3 -DNDEBUG",
"CMAKE_C_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_CXX_FLAGS_RELWITHDEBINFO": "-O3 -DNDEBUG -g",
"CMAKE_PREFIX_PATH": "$env{OPENCL_SDK_ROOT}",
"HEXAGON_SDK_ROOT": "$env{HEXAGON_SDK_ROOT}",
"HEXAGON_TOOLS_ROOT": "$env{HEXAGON_TOOLS_ROOT}",
"PREBUILT_LIB_DIR": "windows_aarch64", "PREBUILT_LIB_DIR": "windows_aarch64",
"GGML_OPENMP": "OFF", "GGML_OPENMP": "OFF",
"GGML_LLAMAFILE": "OFF", "GGML_LLAMAFILE": "OFF",

View File

@ -1,6 +1,8 @@
# Snapdragon-based Android devices # Snapdragon-based devices
## How to Build ## Setup
### Android
The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain). The easiest way to build llama.cpp for a Snapdragon-based Android device is using the toolchain Docker image (see github.com/snapdragon-toolchain).
This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc. This image includes Android NDK, OpenCL SDK, Hexagon SDK, CMake, etc.
@ -12,7 +14,24 @@ This method works on Linux, macOS, and Windows. macOS and Windows users should i
[d]/> cd /workspace [d]/> cd /workspace
``` ```
The rest of the Android build process assumes that you're running inside the toolchain container. Note: The rest of the **Android** build process assumes that you're running inside the toolchain container.
### Windows On Snapdragon
Native Windows 11 arm64 builds have the following tool dependencies:
- MS Visual Studio 2026 (Community Edition or Pro)
- MSVC arm64 standard and runtime libraries
- UCRT and Driver Kit
- LLVM core libraries and Clang compiler (winget)
- CMake, Git, Python (winget)
- Hexagon SDK Community Edition 6.4 or later (see windows.md)
- OpenCL SDK 2.3 or later (see windows.md)
Note: The rest of the **Windows** build process assumes that you're running natively in PowerShell.
Adapt the build commands below accordingly; an illustrative sketch of the winget installs follows this note.
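
A rough sketch of installing the winget-based items above (the package identifiers are assumptions and may differ on your system):

```
> winget install -e --id LLVM.LLVM
> winget install -e --id Kitware.CMake
> winget install -e --id Git.Git
> winget install -e --id Python.Python.3.12
```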
## How to Build
Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets: Let's build llama.cpp with CPU, OpenCL, and Hexagon backends via CMake presets:
``` ```
@ -49,24 +68,26 @@ Preset CMake variables:
To generate an installable "package" simply use cmake --install: To generate an installable "package" simply use cmake --install:
``` ```
[d]/workspace> cmake --install build-snapdragon --prefix pkg-adb/llama.cpp [d]/workspace> cmake --install build-snapdragon --prefix pkg-snapdragon/llama.cpp
-- Install configuration: "Release" -- Install configuration: "Release"
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-cpu.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-cpu.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-opencl.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-opencl.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-hexagon.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-hexagon.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v73.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v73.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v75.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v75.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v79.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v79.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml-htp-v81.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml-htp-v81.so
-- Installing: /workspace/pkg-adb/llama.cpp/lib/libggml.so -- Installing: /workspace/pkg-snapdragon/llama.cpp/lib/libggml.so
... ...
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-bench -- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-bench
-- Installing: /workspace/pkg-adb/llama.cpp/bin/llama-cli -- Installing: /workspace/pkg-snapdragon/llama.cpp/bin/llama-cli
... ...
``` ```
## How to Install ## How to Install
### Android
For this step, your device needs to be configured for on-device development. For this step, your device needs to be configured for on-device development.
Please see https://developer.android.com/studio/debug/dev-options for details. Please see https://developer.android.com/studio/debug/dev-options for details.
@ -74,10 +95,10 @@ Once ADB is enabled, use `adb push` to install `pkg-snapdragon` on the device.
**Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.** **Note that the toolchain Docker image doesn't have ADB and doesn't set up the ADB bridge. Please use native ADB on the host.**
``` ```
~/src/llama.cpp$ adb push pkg-adb/llama.cpp /data/local/tmp/ ~/src/llama.cpp$ adb push pkg-snapdragon/llama.cpp /data/local/tmp/
pkg-adb/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s) pkg-snapdragon/llama.cpp/bin/: 67 files pushed, 0 skipped. 190.2 MB/s (919095042 bytes in 4.607s)
pkg-adb/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s) pkg-snapdragon/llama.cpp/include/: 19 files pushed, 0 skipped. 20.5 MB/s (255173 bytes in 0.012s)
pkg-adb/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s) pkg-snapdragon/llama.cpp/lib/: 16 files pushed, 0 skipped. 144.4 MB/s (43801382 bytes in 0.289s)
102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s) 102 files pushed, 0 skipped. 186.9 MB/s (963151597 bytes in 4.914s)
``` ```
@ -92,6 +113,11 @@ At this point, you should also install some models:
Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s) Llama-3.2-1B-Instruct-Q4_0.gguf: 1 file pushed, 0 skipped. 38.3 MB/s (773025920 bytes in 19.250s)
``` ```
### Windows
All artifacts are already installed in the `pkg-snapdragon` folder.
To run, adapt the instructions below to use the PowerShell scripts in `scripts/snapdragon/windows`.
## How to Run ## How to Run
The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables. The easiest way to run llama.cpp cli tools is using provided wrapper scripts that properly set up all required environment variables.

View File

@ -0,0 +1,161 @@
## Overview
This document covers procedures for installing the latest GPU and NPU drivers, and the OpenCL and Hexagon SDKs.
In order to use the Hexagon NPU on Snapdragon Windows devices, the underlying HTP Ops libraries (e.g. libggml-htp-v73.so)
must be included in a .cat file that is digitally signed with a trusted certificate.
This document also details how to generate personal certificate files (.pfx) and how to configure the system
to allow test signatures (aka test-signing).
## Install the latest Adreno OpenCL SDK
Either use the trimmed down version (optimized for CI) from
https://github.com/snapdragon-toolchain/opencl-sdk/releases/download/v2.3.2/adreno-opencl-sdk-v2.3.2-arm64-wos.tar.xz
Or download the complete official version from
https://softwarecenter.qualcomm.com/catalog/item/Adreno_OpenCL_SDK?version=2.3.2
Unzip/untar the archive into
```
c:\Qualcomm\OpenCL_SDK\2.3.2
```
## Install the latest Hexagon SDK Community Edition
Either use the trimmed down version (optimized for CI) from
https://github.com/snapdragon-toolchain/hexagon-sdk/releases/download/v6.4.0.2/hexagon-sdk-v6.4.0.2-arm64-wos.tar.xz
Or download the complete official version from
https://softwarecenter.qualcomm.com/catalog/item/Hexagon_SDK?version=6.4.0.2
Unzip/untar the archive into
```
c:\Qualcomm\Hexagon_SDK\6.4.0.2
```
## Install the latest Adreno GPU driver
Download the driver from
https://softwarecenter.qualcomm.com/catalog/item/Windows_Graphics_Driver
After the automated installation and reboot, please make sure that the GPU device shows up in the `Device Manager` (under `Display Adapters`).
## Install the latest Qualcomm NPU driver
Download the driver from
https://softwarecenter.qualcomm.com/catalog/item/Qualcomm_HND
After the automated installation and reboot, please make sure that the Hexagon NPU device shows up in the `Device Manager` (under `Neural Processors`).
If the device is not available, you can try installing all of the components (`qcnspmcdm8380`, `qcnspmcdm8380_ext`) manually.
The components are extracted into
```
c:\QCDrivers\qcnspmcdm...
```
## Enable NPU driver test signatures
Please note that the following steps are required only for the Hexagon NPU.
The Adreno GPU backend does not require test signatures.
### Enable testsigning
Use `bcdedit` to enable test-signing
```
> bcdedit /set TESTSIGNING ON
```
(Secure Boot may need to be disabled for this to work)
Make sure test-signing is enabled after reboot
```
> bcdedit /enum
...
testsigning Yes
...
```
For additional details see Microsoft guide at
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/the-testsigning-boot-configuration-option
### Create personal certificate
The tools required for this procedure are available as part of the Windows SDK and the Windows Driver Kit, which should be
installed as part of MS Visual Studio.
They are typically located at
```
c:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0
```
(replace 10.0.26100.0 with the correct version).
To create a personal self-signed certificate, run the following commands (either from cmd or PowerShell):
```
> cd c:\Users\MyUser
> mkdir Certs
> cd Certs
> makecert -r -pe -ss PrivateCertStore -n CN=GGML.HTP.v1 -eku 1.3.6.1.5.5.7.3.3 -sv ggml-htp-v1.pvk ggml-htp-v1.cer
> pvk2pfx.exe -pvk ggml-htp-v1.pvk -spc ggml-htp-v1.cer -pfx ggml-htp-v1.pfx
```
(replace `MyUser` with your username).
Add this certificate to the `Trusted Root Certification Authorities` and `Trusted Publishers` stores.
This can be done using the `certlm` Certificate Manager tool:
right-click the certificate store, select `All Tasks -> Import`, and follow the prompts to import the certificate from the
PFX file you created above.
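
Alternatively (a possible command-line route, not part of the official procedure), the public certificate can be added to both stores with `certutil` from an elevated prompt:

```
> certutil -addstore Root ggml-htp-v1.cer
> certutil -addstore TrustedPublisher ggml-htp-v1.cer
```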
For additional details see Microsoft guide at
https://learn.microsoft.com/en-us/windows-hardware/drivers/install/introduction-to-test-signing
Make sure to save the PFX file; you will need it for the build procedures.
Please note that the same certificate can be used to sign any number of builds.
## Build Hexagon backend with signed HTP ops libraries
The overall Hexagon backend build procedure for Windows on Snapdragon is the same as for other platforms.
However, additional settings are required for generating and signing HTP Ops libraries.
```
> $env:OPENCL_SDK_ROOT="C:\Qualcomm\OpenCL_SDK\2.3.2"
> $env:HEXAGON_SDK_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2"
> $env:HEXAGON_TOOLS_ROOT="C:\Qualcomm\Hexagon_SDK\6.4.0.2\tools\HEXAGON_Tools\19.0.04"
> $env:HEXAGON_HTP_CERT="c:\Users\MyUsers\Certs\ggml-htp-v1.pfx"
> $env:WINDOWS_SDK_BIN="C:\Program Files (x86)\Windows Kits\10\bin\10.0.26100.0\arm64"
> cmake --preset arm64-windows-snapdragon-release -B build-wos
...
> cmake --install build-wos --prefix pkg-snapdragon
```
Once the build is complete, the HTP ops libraries will be installed like this:
```
> dir pkg-snapdragon/lib
...
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v73.so
-a---- 1/22/2026 6:01 PM 191752 libggml-htp-v75.so
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v79.so
-a---- 1/22/2026 6:01 PM 187656 libggml-htp-v81.so
-a---- 1/22/2026 6:01 PM 4139 libggml-htp.cat
```
The .cat file, its signature, and proper certificate installation can be verified with
```
> signtool.exe verify /v /pa .\pkg-snapdragon\lib\libggml-htp.cat
Verifying: .\pkg-snapdragon\lib\libggml-htp.cat
Signature Index: 0 (Primary Signature)
Hash of file (sha256): 9820C664DA59D5EAE31DBB664127FCDAEF59CDC31502496BC567544EC2F401CF
Signing Certificate Chain:
Issued to: GGML.HTP.v1
...
Successfully verified: .\pkg-snapdragon\lib\libggml-htp.cat
...
```

View File

@ -252,9 +252,7 @@ CUDA_VISIBLE_DEVICES="-0" ./build/bin/llama-server --model /srv/models/llama.ggu
The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU. The environment variable [`CUDA_SCALE_LAUNCH_QUEUES`](https://docs.nvidia.com/cuda/cuda-programming-guide/05-appendices/environment-variables.html#cuda-scale-launch-queues) controls the size of CUDA's command buffer, which determines how many GPU operations can be queued before the CPU must wait for the GPU to catch up. A larger buffer reduces CPU-side stalls and allows more work to be queued on a GPU.
**Default behavior:** llama.cpp automatically sets `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs. Consider setting `CUDA_SCALE_LAUNCH_QUEUES=4x`, which increases the CUDA command buffer to 4 times its default size. This optimization is particularly beneficial for **Multi-GPU setups with pipeline parallelism**, where it significantly improves prompt processing throughput by allowing more operations to be enqueued across GPUs.
See PR [#19042](https://github.com/ggml-org/llama.cpp/pull/19042) for performance benchmarks and technical details.
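
For example (an illustrative invocation, reusing the server command shown earlier in this document):

```sh
# enlarge the CUDA command buffer for a multi-GPU pipeline-parallel run
CUDA_SCALE_LAUNCH_QUEUES=4x ./build/bin/llama-server --model /srv/models/llama.gguf
```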
### Unified Memory ### Unified Memory

View File

@ -9,7 +9,7 @@ Download [MiniCPM-o-2_6](https://huggingface.co/openbmb/MiniCPM-o-2_6) PyTorch m
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250206 Readme modification time: 20250206
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash

View File

@ -8,11 +8,11 @@ Download [MiniCPM-o-4](https://huggingface.co/openbmb/MiniCPM-o-4) PyTorch model
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250206 Readme modification time: 20250206
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash
git clone https://github.com/ggerganov/llama.cpp git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp cd llama.cpp
``` ```

View File

@ -8,7 +8,7 @@ Download [MiniCPM-Llama3-V-2_5](https://huggingface.co/openbmb/MiniCPM-Llama3-V-
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250206 Readme modification time: 20250206
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash

View File

@ -8,7 +8,7 @@ Download [MiniCPM-V-2_6](https://huggingface.co/openbmb/MiniCPM-V-2_6) PyTorch m
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250206 Readme modification time: 20250206
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash

View File

@ -8,11 +8,11 @@ Download [MiniCPM-V-4](https://huggingface.co/openbmb/MiniCPM-V-4) PyTorch model
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250731 Readme modification time: 20250731
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash
git clone https://github.com/ggerganov/llama.cpp git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp cd llama.cpp
``` ```

View File

@ -8,11 +8,11 @@ Download [MiniCPM-V-4_5](https://huggingface.co/openbmb/MiniCPM-V-4_5) PyTorch m
### Build llama.cpp ### Build llama.cpp
Readme modification time: 20250826 Readme modification time: 20250826
If there are differences in usage, please refer to the official build [documentation](https://github.com/ggerganov/llama.cpp/blob/master/docs/build.md) If there are differences in usage, please refer to the official build [documentation](https://github.com/ggml-org/llama.cpp/blob/master/docs/build.md)
Clone llama.cpp: Clone llama.cpp:
```bash ```bash
git clone https://github.com/ggerganov/llama.cpp git clone https://github.com/ggml-org/llama.cpp
cd llama.cpp cd llama.cpp
``` ```

View File

@ -97,7 +97,7 @@ Legend:
| SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ | | SILU | ❌ | ✅ | ✅ | 🟡 | 🟡 | 🟡 | ✅ | 🟡 | ✅ | ❌ | ❌ |
| SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | | SILU_BACK | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ |
| SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | | SIN | ❌ | ✅ | ✅ | ✅ | 🟡 | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ |
| SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ | | SOFTPLUS | ❌ | ❌ | ✅ | 🟡 | 🟡 | ❌ | | 🟡 | ✅ | ❌ | ❌ |
| SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | | SOFT_MAX | ❌ | 🟡 | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ |
| SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ | | SOFT_MAX_BACK | ❌ | ❌ | 🟡 | 🟡 | ❌ | ❌ | 🟡 | ✅ | ❌ | ❌ | ❌ |
| SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ | | SOLVE_TRI | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | ❌ | 🟡 | ❌ | ❌ | ❌ |
@ -113,8 +113,8 @@ Legend:
| SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | SWIGLU_OAI | ❌ | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ | | TANH | ❌ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | 🟡 | ✅ | ❌ | ❌ |
| TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | | TIMESTEP_EMBEDDING | ❌ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ✅ | ❌ | ❌ | ❌ |
| TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | | 🟡 | ✅ | ❌ | ❌ | | TOP_K | ❌ | ❌ | ✅ | ❌ | ✅ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ | | TRI | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | | ✅ | ❌ | ❌ | ❌ |
| TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ | | TRUNC | ❌ | ❌ | ✅ | 🟡 | ❌ | ❌ | 🟡 | 🟡 | ✅ | ❌ | ❌ |
| UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ | | UPSCALE | ❌ | 🟡 | ✅ | ✅ | 🟡 | 🟡 | 🟡 | 🟡 | ❌ | ❌ | ❌ |
| XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ | | XIELU | ❌ | ❌ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ❌ | ❌ |

File diff suppressed because it is too large Load Diff

View File

@ -6,7 +6,7 @@ llama.cpp supports speculative decoding, a technique that can significantly acce
## Implementations ## Implementations
The `llama-server` application supports several implementations of speculative decoding: The `llama-server` application supports several implementations of speculative decoding. An implementation with a draft model can be mixed with an implementation without a draft model.
### Draft Model (`draft`) ### Draft Model (`draft`)
@ -32,12 +32,21 @@ An example to use this approach can be the rewriting of source code by a LLM.
This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead. This implementation looks for the last n-gram in history that matches the current n-gram and creates a draft using the m tokens following the matched n-gram. It is the simplest self-speculative approach with minimal overhead.
```
llama-server [...] --spec-type ngram-simple --draft-max 64
```
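For illustration, here is a minimal sketch of this matching step in C++, assuming the history is a flat token vector; the names and data structures are simplified and not the actual server internals:
```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Scan the history backwards for an earlier occurrence of the trailing n-gram
// and propose the up-to-m tokens that followed it as the draft.
static std::vector<int32_t> ngram_simple_draft(const std::vector<int32_t> & hist, size_t n, size_t m) {
    if (hist.size() < n + 1) {
        return {};
    }
    const int32_t * cur = hist.data() + hist.size() - n; // current (trailing) n-gram
    for (size_t pos = hist.size() - n; pos-- > 0; ) {
        if (std::equal(cur, cur + n, hist.data() + pos)) {
            const size_t start = pos + n;                        // tokens following the match
            const size_t len   = std::min(m, hist.size() - start);
            return std::vector<int32_t>(hist.begin() + start, hist.begin() + start + len);
        }
    }
    return {};
}
```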
#### n-gram Map Key (`ngram-map-k`) #### n-gram Map Key (`ngram-map-k`)
This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`) before generating drafts. This implementation looks for the current n-gram of size n (called the _key_) in the token history. If the key n-gram is followed by the same m tokens (called the _mgram_) multiple times, it creates a draft using these m tokens. This approach requires a minimum number of occurrences (argument `--spec-ngram-min-hits`, default is 1) before generating drafts.
The number of accepted tokens is stored for each used n-gram. The number of accepted tokens is stored for each used n-gram.
**Example:**
```
llama-server [...] --spec-type ngram-map-k --draft-max 64
```
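A rough sketch of the bookkeeping this implies is shown below; the map layout is simplified, and the experimental `ngram-map-k4v` variant described next additionally keeps up to four candidate m-grams per key:
```cpp
#include <cstdint>
#include <map>
#include <vector>

struct mgram_entry {
    std::vector<int32_t> tokens; // the m tokens observed after the key n-gram
    int                  hits = 0;
};

// key n-gram -> the m-gram that followed it and how often it was seen
using ngram_map = std::map<std::vector<int32_t>, mgram_entry>;

// Draft only when the key has been seen often enough (cf. --spec-ngram-min-hits).
static std::vector<int32_t> ngram_map_k_draft(const ngram_map & map,
                                              const std::vector<int32_t> & key,
                                              int min_hits) {
    const auto it = map.find(key);
    if (it != map.end() && it->second.hits >= min_hits) {
        return it->second.tokens;
    }
    return {};
}
```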
#### n-gram Map Key-4-Values (`ngram-map-k4v`) #### n-gram Map Key-4-Values (`ngram-map-k4v`)
This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft. This experimental implementation looks for the current n-gram of size n (called the _key_) in the token history. For each key, up to four _values_ (n-grams of size m, called _mgrams_) are tracked. An internal statistic counts the occurrences of each mgram after the key n-gram. If one mgram is significantly more frequent than the others, it is used as the draft.
@ -45,17 +54,65 @@ This experimental implementation looks for the current n-gram of size n (called
The number of accepted tokens is stored for each used n-gram. The number of accepted tokens is stored for each used n-gram.
**Example:** Server options to use when the text contains many long repetitions. **Example:** Server options to use when the text contains many long repetitions.
```bash ```
llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 llama-server [...] --spec-type ngram-map-k4v --spec-ngram-size-n 8 --spec-ngram-size-m 8 --spec-ngram-min-hits 2 --draft-max 64
``` ```
### n-gram Mod (`ngram-mod`)
This implementation adds a basic n-gram hasher for speculative decoding:
- For each ngram, compute a hash using LCG
- For each computed hash, store the next token
- During speculation, iteratively compute the rolling hash of the last n tokens and pick the next token from the storage
Some characteristics:
- Lightweight (~16 MB)
- Constant memory and complexity
- Can generate variable draft lengths (i.e. m is not fixed)
Currently, a single hash pool is shared across all server slots, so different requests can benefit from each other.
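As a rough illustration of the idea, a minimal sketch follows; the constants, pool size and layout are made up and are not the ones used by `ngram-mod`:
```cpp
#include <cstdint>
#include <vector>

// Fixed-size pool mapping the hash of the last n tokens to the next token (-1 = empty slot).
struct ngram_hash_pool {
    std::vector<int32_t> next;

    explicit ngram_hash_pool(size_t size) : next(size, -1) {}

    static uint64_t lcg_step(uint64_t h, int32_t tok) {
        // illustrative LCG constants
        return h * 6364136223846793005ULL + (uint64_t) tok + 1442695040888963407ULL;
    }

    static uint64_t hash(const int32_t * toks, size_t n) {
        uint64_t h = 0;
        for (size_t i = 0; i < n; ++i) {
            h = lcg_step(h, toks[i]);
        }
        return h;
    }

    void observe(const int32_t * toks, size_t n, int32_t next_tok) {
        next[hash(toks, n) % next.size()] = next_tok;
    }

    int32_t lookup(const int32_t * toks, size_t n) const {
        return next[hash(toks, n) % next.size()];
    }
};
```
During drafting, the lookup is repeated on a sliding window of the last n tokens, appending each predicted token and hashing again until a miss or the draft limit is reached, which is what allows variable draft lengths.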
**Sample usage:**
```
# notes:
# - small values of `n` are not recommended
# - MoEs require long drafts
# - dense models: can reduce `--draft-min` and `--draft-max`
llama-server ... --spec-type ngram-mod --spec-ngram-size-n 24 --draft-min 48 --draft-max 64
```
Applications:
- Iterating over a block of text/code (e.g. in llama.vim)
- Reasoning models (when they have to repeat their thinking in the final answer)
- Summarization
Example Video:
- See #19164
### Differences between ngram-simple, ngram-map and ngram-mod
- ngram-simple looks for a previous matching n-gram and inserts the following m-gram.
- ngram-map-k looks for a previous matching n-gram and inserts the following m-gram but uses an internal hash-map of n-grams in the current context window.
- ngram-mod uses a hash pool which is shared across all server slots. The hash pool is a map from n-gram hash to the next token (not the next m-gram as in ngram-map).
## Command-Line Options ## Command-Line Options
If a draft model is combined with a draftless decoding method, the draftless method takes precedence. If a draft model is combined with a draftless decoding method, the draftless method takes precedence.
``` ```
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v] --draft, --draft-n, --draft-max N number of tokens to draft for speculative decoding (default: 16)
(env: LLAMA_ARG_DRAFT_MAX)
--draft-min, --draft-n-min N minimum number of draft tokens to use for speculative decoding
(default: 0)
(env: LLAMA_ARG_DRAFT_MIN)
[...]
--spec-type [none|ngram-cache|ngram-simple|ngram-map-k|ngram-map-k4v|ngram-mod]
type of speculative decoding to use when no draft model is provided type of speculative decoding to use when no draft model is provided
(default: none) (default: none)
--spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length --spec-ngram-size-n N ngram size N for ngram-simple/ngram-map speculative decoding, length
@ -78,6 +135,7 @@ Specifies a type of speculative decoding without draft model.
| `ngram-simple` | Use simple n-gram pattern matching | | `ngram-simple` | Use simple n-gram pattern matching |
| `ngram-map-k` | Use n-gram pattern matching with n-gram-keys | | `ngram-map-k` | Use n-gram pattern matching with n-gram-keys |
| `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) | | `ngram-map-k4v` | Use n-gram pattern matching with n-gram-keys and up to four m-gram values (experimental) |
| `ngram-mod` | Use basic ngram hasher for speculative decoding with shared pool |
**Example:** Server-instance used to refactor source code. **Example:** Server-instance used to refactor source code.
```bash ```bash
@ -112,9 +170,15 @@ statistics ngram_simple: #calls = 15, #gen drafts = 5, #acc drafts = 5, #gen tok
statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98 statistics draft: #calls = 10, #gen drafts = 10, #acc drafts = 10, #gen tokens = 110, #acc tokens = 98
``` ```
```
draft acceptance rate = 0.70312 ( 90 accepted / 128 generated)
statistics ngram_mod: #calls = 810, #gen drafts = 15, #acc drafts = 15, #gen tokens = 960, #acc tokens = 730, dur(b,g,a) = 0.149, 0.347, 0.005 ms
```
- `#calls`: number of calls of this implementation - `#calls`: number of calls of this implementation
- `#gen drafts`: number of drafts generated by this implementation - `#gen drafts`: number of drafts generated by this implementation
- `#acc drafts`: number of drafts accepted (partially) by the main model - `#acc drafts`: number of drafts accepted (partially) by the main model
- `#gen tokens`: number of tokens generated by this implementation (including rejected tokens) - `#gen tokens`: number of tokens generated by this implementation (including rejected tokens)
- `#acc tokens`: number of tokens accepted by the main model - `#acc tokens`: number of tokens accepted by the main model
- `dur(b,g,a)`: durations of begin (new prompt), generation, and accumulation (process acceptance).

View File

@ -1,7 +1,7 @@
# Migration notice for binary filenames # Migration notice for binary filenames
> [!IMPORTANT] > [!IMPORTANT]
[2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggerganov/llama.cpp/pull/7809) [2024 Jun 12] Binaries have been renamed w/ a `llama-` prefix. `main` is now `llama-cli`, `server` is `llama-server`, etc (https://github.com/ggml-org/llama.cpp/pull/7809)
This migration was important, but it is a breaking change that may not always be immediately obvious to users. This migration was important, but it is a breaking change that may not always be immediately obvious to users.

View File

@ -28,7 +28,7 @@ int main(int argc, char** argv) {
fprintf(stdout, "\n"); fprintf(stdout, "\n");
fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str()); fprintf(stdout, "WARNING: The binary '%s' is deprecated.\n", filename.c_str());
fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str()); fprintf(stdout, " Please use '%s' instead.\n", replacement_filename.c_str());
fprintf(stdout, " See https://github.com/ggerganov/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n"); fprintf(stdout, " See https://github.com/ggml-org/llama.cpp/tree/master/examples/deprecation-warning/README.md for more information.\n");
fprintf(stdout, "\n"); fprintf(stdout, "\n");
return EXIT_FAILURE; return EXIT_FAILURE;

View File

@ -402,7 +402,7 @@ class SchemaConverter:
Transforms a regular expression pattern into a GBNF rule. Transforms a regular expression pattern into a GBNF rule.
Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions Input: https://json-schema.org/understanding-json-schema/reference/regular_expressions
Output: https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md Output: https://github.com/ggml-org/llama.cpp/blob/master/grammars/README.md
Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers. Unsupported features: negative/positive lookaheads, greedy/non-greedy modifiers.

View File

@ -50,6 +50,12 @@ int main(int argc, char ** argv) {
const int N = 5; // n-gram size const int N = 5; // n-gram size
const int G = 15; // max verification n-grams const int G = 15; // max verification n-grams
// lookahead requires W + G + 1 sequences for parallel Jacobi decoding
params.n_parallel = W + G + 1;
// unified KV cache is required for coupled sequences in batch splitting
params.kv_unified = true;
// init llama.cpp // init llama.cpp
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);
@ -115,7 +121,7 @@ int main(int argc, char ** argv) {
// seq_id == 0 : the current input token // seq_id == 0 : the current input token
// seq_id [1, W] : tokens from the past N - 1 Jacobi iterations // seq_id [1, W] : tokens from the past N - 1 Jacobi iterations
// seq_id [W + 1, W + G] : verification n-grams // seq_id [W + 1, W + G] : verification n-grams
llama_batch batch = llama_batch_init(params.n_ctx, 0, W + G + 1); llama_batch batch = llama_batch_init(llama_n_ctx(ctx), 0, W + G + 1);
// target model sampling context // target model sampling context
struct common_sampler * smpl = common_sampler_init(model, params.sampling); struct common_sampler * smpl = common_sampler_init(model, params.sampling);

View File

@ -106,7 +106,7 @@ int main(int argc, char ** argv){
std::vector<llama_token> draft; std::vector<llama_token> draft;
llama_batch batch_tgt = llama_batch_init(params.n_ctx, 0, 1); llama_batch batch_tgt = llama_batch_init(llama_n_ctx(ctx), 0, 1);
const auto t_dec_start = ggml_time_us(); const auto t_dec_start = ggml_time_us();

View File

@ -33,11 +33,14 @@ DEVICE ?= auto
causal-convert-model-bf16: OUTTYPE=bf16 causal-convert-model-bf16: OUTTYPE=bf16
causal-convert-model-bf16: causal-convert-model causal-convert-model-bf16: causal-convert-model
causal-convert-model-debug: DEBUG=--debug
causal-convert-model-debug: causal-convert-model
causal-convert-model: causal-convert-model:
$(call validate_model_path,causal-convert-model) $(call validate_model_path,causal-convert-model)
@MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \ @MODEL_NAME="$(MODEL_NAME)" OUTTYPE="$(OUTTYPE)" MODEL_PATH="$(MODEL_PATH)" \
METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \ METADATA_OVERRIDE="$(METADATA_OVERRIDE)" \
./scripts/causal/convert-model.sh ./scripts/causal/convert-model.sh $(DEBUG)
causal-convert-mm-model-bf16: OUTTYPE=bf16 causal-convert-mm-model-bf16: OUTTYPE=bf16
causal-convert-mm-model-bf16: MM_OUTTYPE=f16 causal-convert-mm-model-bf16: MM_OUTTYPE=f16

View File

@ -4,12 +4,17 @@ set -e
# Parse command line arguments # Parse command line arguments
MMPROJ="" MMPROJ=""
DEBUG=""
while [[ $# -gt 0 ]]; do while [[ $# -gt 0 ]]; do
case $1 in case $1 in
--mmproj) --mmproj)
MMPROJ="--mmproj" MMPROJ="--mmproj"
shift shift
;; ;;
--debug)
DEBUG="1"
shift
;;
*) *)
shift shift
;; ;;
@ -28,7 +33,12 @@ echo "Data type: ${TYPE}"
echo "Converted model path:: ${CONVERTED_MODEL}" echo "Converted model path:: ${CONVERTED_MODEL}"
echo "Metadata override: ${METADATA_OVERRIDE}" echo "Metadata override: ${METADATA_OVERRIDE}"
CMD_ARGS=("python" "../../convert_hf_to_gguf.py" "--verbose") if [[ -n "$DEBUG" ]]; then
CMD_ARGS=("python" "-m" "pdb")
else
CMD_ARGS=("python")
fi
CMD_ARGS+=("../../convert_hf_to_gguf.py" "--verbose")
CMD_ARGS+=("${MODEL_PATH}") CMD_ARGS+=("${MODEL_PATH}")
CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}") CMD_ARGS+=("--outfile" "${CONVERTED_MODEL}")
CMD_ARGS+=("--outtype" "${TYPE}") CMD_ARGS+=("--outtype" "${TYPE}")

View File

@ -0,0 +1,159 @@
#!/usr/bin/env python3
import argparse
import json
import os
import re
import sys
from pathlib import Path
from typing import Optional
from safetensors import safe_open
MODEL_SAFETENSORS_FILE = "model.safetensors"
MODEL_SAFETENSORS_INDEX = "model.safetensors.index.json"
def get_weight_map(model_path: Path) -> Optional[dict[str, str]]:
index_file = model_path / MODEL_SAFETENSORS_INDEX
if index_file.exists():
with open(index_file, 'r') as f:
index = json.load(f)
return index.get("weight_map", {})
return None
def get_all_tensor_names(model_path: Path) -> list[str]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
return list(weight_map.keys())
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
try:
with safe_open(single_file, framework="pt", device="cpu") as f:
return list(f.keys())
except Exception as e:
print(f"Error reading {single_file}: {e}")
sys.exit(1)
print(f"Error: No safetensors files found in {model_path}")
sys.exit(1)
def find_tensor_file(model_path: Path, tensor_name: str) -> Optional[str]:
weight_map = get_weight_map(model_path)
if weight_map is not None:
return weight_map.get(tensor_name)
single_file = model_path / MODEL_SAFETENSORS_FILE
if single_file.exists():
return single_file.name
return None
def normalize_tensor_name(tensor_name: str) -> str:
normalized = re.sub(r'\.\d+\.', '.#.', tensor_name)
normalized = re.sub(r'\.\d+$', '.#', normalized)
return normalized
def list_all_tensors(model_path: Path, unique: bool = False):
tensor_names = get_all_tensor_names(model_path)
if unique:
seen = set()
for tensor_name in sorted(tensor_names):
normalized = normalize_tensor_name(tensor_name)
if normalized not in seen:
seen.add(normalized)
print(normalized)
else:
for tensor_name in sorted(tensor_names):
print(tensor_name)
def print_tensor_info(model_path: Path, tensor_name: str):
tensor_file = find_tensor_file(model_path, tensor_name)
if tensor_file is None:
print(f"Error: Could not find tensor '{tensor_name}' in model index")
print(f"Model path: {model_path}")
sys.exit(1)
file_path = model_path / tensor_file
try:
with safe_open(file_path, framework="pt", device="cpu") as f:
if tensor_name in f.keys():
tensor_slice = f.get_slice(tensor_name)
shape = tensor_slice.get_shape()
print(f"Tensor: {tensor_name}")
print(f"File: {tensor_file}")
print(f"Shape: {shape}")
else:
print(f"Error: Tensor '{tensor_name}' not found in {tensor_file}")
sys.exit(1)
except FileNotFoundError:
print(f"Error: The file '{file_path}' was not found.")
sys.exit(1)
except Exception as e:
print(f"An error occurred: {e}")
sys.exit(1)
def main():
parser = argparse.ArgumentParser(
description="Print tensor information from a safetensors model"
)
parser.add_argument(
"tensor_name",
nargs="?", # optional (if --list is used for example)
help="Name of the tensor to inspect"
)
parser.add_argument(
"-m", "--model-path",
type=Path,
help="Path to the model directory (default: MODEL_PATH environment variable)"
)
parser.add_argument(
"-l", "--list",
action="store_true",
help="List unique tensor patterns in the model (layer numbers replaced with #)"
)
args = parser.parse_args()
model_path = args.model_path
if model_path is None:
model_path_str = os.environ.get("MODEL_PATH")
if model_path_str is None:
print("Error: --model-path not provided and MODEL_PATH environment variable not set")
sys.exit(1)
model_path = Path(model_path_str)
if not model_path.exists():
print(f"Error: Model path does not exist: {model_path}")
sys.exit(1)
if not model_path.is_dir():
print(f"Error: Model path is not a directory: {model_path}")
sys.exit(1)
if args.list:
list_all_tensors(model_path, unique=True)
else:
if args.tensor_name is None:
print("Error: tensor_name is required when not using --list")
sys.exit(1)
print_tensor_info(model_path, args.tensor_name)
if __name__ == "__main__":
main()

View File

@ -18,13 +18,14 @@ CONTEXT=4096
#support malloc device memory more than 4GB. #support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
LOAD_MODE='--mmap'
if [ $# -gt 0 ]; then if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1 GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU" echo "use $GGML_SYCL_DEVICE as main GPU"
#use single GPU only #use single GPU only
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ${LOAD_MODE}
else else
#use multiple GPUs with same max compute units #use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} ${LOAD_MODE}
fi fi

View File

@ -1,31 +0,0 @@
#!/usr/bin/env bash
# MIT license
# Copyright (C) 2025 Intel Corporation
# SPDX-License-Identifier: MIT
# If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
#ZES_ENABLE_SYSMAN=1, Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory. Recommended to use when --split-mode = layer.
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-completion -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi

130
examples/sycl/test.sh Executable file
View File

@ -0,0 +1,130 @@
#!/bin/bash
# MIT license
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: MIT
Help() {
cat << EOF
Usage: $(basename "$0") [OPTIONS]
This script runs llama-completion on Intel GPUs with the specified options.
Options:
-h, --help Display this help message and exit.
-c, --context <value>       Set context length. Larger values need more memory.
-p, --promote <value> Prompt to start generation with.
-m, --model <value> Full model file path.
-mg,--main-gpu <value> Set main GPU ID (0 - n) for single GPU mode.
-sm,--split-mode <value> How to split the model across multiple GPUs, one of:
- none: use one GPU only
- layer (default): split layers and KV across GPUs
- row: split rows across GPUs
-ngl,--n-gpu-layers <value> Max. number of layers to store in VRAM (default: 99)
-lv,--log-verbosity <value> Set the verbosity threshold. Messages with a higher verbosity will be
ignored. Values:
- 0: generic output
- 1: error
- 2: warning
- 3: info
- 4: debug
EOF
}
BIN_FILE=./build/bin/llama-completion
SEED=0
GPUS_SETTING=""
INPUT_PROMPT="Building a website can be done in 10 simple steps:\nStep 1:"
MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
GGML_SYCL_DEVICE=-1
SPLIT_MODE=layer
LOG_VERBOSE=3
while [[ $# -gt 0 ]]; do
case "$1" in
-c|--context)
CONTEXT=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-p|--promote)
INPUT_PROMPT="$2"
# Shift twice to consume both the option flag and its value
shift
shift
;;
-m|--model)
MODEL_FILE="$2"
# Shift twice to consume both the option flag and its value
shift
shift
;;
-mg|--main-gpu)
GGML_SYCL_DEVICE=$2
SPLIT_MODE=none
# Shift twice to consume both the option flag and its value
shift
shift
;;
-sm|--split-mode)
SPLIT_MODE=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-ngl|--n-gpu-layers)
NGL=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-lv|--log-verbosity)
LOG_VERBOSE=$2
# Shift twice to consume both the option flag and its value
shift
shift
;;
-h|--help)
Help
exit 0
;;
*)
# Handle unknown options or stop processing options
echo "Invalid option: $1"
# Optional: exit script or shift to treat remaining as positional args
exit 1
;;
esac
done
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
#ZES_ENABLE_SYSMAN=1: allows querying the free GPU memory via sycl::aspect::ext_intel_free_memory. Recommended when --split-mode = layer.
#support allocating device memory larger than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
echo "UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=${UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS}"
if [ $GGML_SYCL_DEVICE -ne -1 ]; then
echo "Use $GGML_SYCL_DEVICE as main GPU"
#use single GPU only
GPUS_SETTING="-mg $GGML_SYCL_DEVICE -sm ${SPLIT_MODE}"
export ONEAPI_DEVICE_SELECTOR="level_zero:${GGML_SYCL_DEVICE}"
echo "ONEAPI_DEVICE_SELECTOR=${ONEAPI_DEVICE_SELECTOR}"
else
echo "Use all Intel GPUs, including iGPU & dGPU"
fi
echo "run cmd: ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap "
ZES_ENABLE_SYSMAN=1 ${BIN_FILE} -m ${MODEL_FILE} -no-cnv -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s ${SEED} -c ${CONTEXT} ${GPUS_SETTING} -lv ${LOG_VERBOSE} --mmap

View File

@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
:: support malloc device memory more than 4GB. :: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set LOAD_MODE="--mmap"
.\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 .\build\bin\llama-completion.exe -m models\llama-2-7b.Q4_0.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%

View File

@ -7,5 +7,5 @@ set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
:: support malloc device memory more than 4GB. :: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1 set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set LOAD_MODE="--mmap"
.\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -s 0 -e -ngl 99 .\build\bin\llama-completion.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -no-cnv -p %INPUT2% -n 400 -e -ngl 99 -s 0 %LOAD_MODE%

View File

@ -1,4 +1,4 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. cmake_minimum_required(VERSION 3.14...3.28) # for add_link_options and implicit target directories.
project("ggml" C CXX ASM) project("ggml" C CXX ASM)
### GGML Version ### GGML Version

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -19,6 +19,9 @@ extern "C" {
// abort ggml_graph_compute when true // abort ggml_graph_compute when true
ggml_abort_callback abort_callback; ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
// use only reference implementations
bool use_ref;
}; };
// numa strategies // numa strategies
@ -132,6 +135,8 @@ extern "C" {
GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool); GGML_BACKEND_API void ggml_backend_cpu_set_threadpool (ggml_backend_t backend_cpu, ggml_threadpool_t threadpool);
GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data); GGML_BACKEND_API void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_callback abort_callback, void * abort_callback_data);
GGML_BACKEND_API void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref);
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
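A minimal usage sketch of the new `ggml_backend_cpu_set_use_ref` API, assuming an already-built graph `gf`; this is illustrative only, e.g. for cross-checking optimized kernels against the reference path:
```cpp
#include "ggml-backend.h"
#include "ggml-cpu.h"

static void compute_with_reference_kernels(struct ggml_cgraph * gf) {
    ggml_backend_t backend = ggml_backend_cpu_init();
    ggml_backend_cpu_set_n_threads(backend, 4);
    ggml_backend_cpu_set_use_ref(backend, true); // force the reference implementations
    ggml_backend_graph_compute(backend, gf);
    ggml_backend_free(backend);
}
```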

View File

@ -7,8 +7,6 @@
extern "C" { extern "C" {
#endif #endif
#define GGML_REMOTING_FRONTEND_NAME "RemotingFrontend"
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg(); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_virtgpu_reg();
#ifdef __cplusplus #ifdef __cplusplus

View File

@ -6,7 +6,7 @@
// This documentation is still a work in progress. // This documentation is still a work in progress.
// If you wish some specific topics to be covered, feel free to drop a comment: // If you wish some specific topics to be covered, feel free to drop a comment:
// //
// https://github.com/ggerganov/whisper.cpp/issues/40 // https://github.com/ggml-org/whisper.cpp/issues/40
// //
// ## Overview // ## Overview
// //

View File

@ -222,6 +222,7 @@ if (GGML_SCHED_NO_REALLOC)
endif() endif()
add_library(ggml add_library(ggml
ggml-backend-dl.cpp
ggml-backend-reg.cpp) ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml) add_library(ggml::ggml ALIAS ggml)

View File

@ -0,0 +1,48 @@
#include "ggml-backend-dl.h"
#ifdef _WIN32
dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
const char * dl_error() {
return "";
}
#else
dl_handle * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif

View File

@ -0,0 +1,45 @@
#pragma once
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include <filesystem>
namespace fs = std::filesystem;
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
dl_handle * dl_load_library(const fs::path & path);
void * dl_get_sym(dl_handle * handle, const char * name);
const char * dl_error();
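A minimal sketch of how these helpers compose; the library path and symbol name below are illustrative:
```cpp
#include "ggml-backend-dl.h"

#include <cstdio>
#include <memory>

int main() {
    dl_handle_ptr handle { dl_load_library("./libggml-cpu.so") };
    if (!handle) {
        std::fprintf(stderr, "load failed: %s\n", dl_error());
        return 1;
    }
    // resolve an entry point; the handle is freed by dl_handle_deleter on scope exit
    void * sym = dl_get_sym(handle.get(), "ggml_backend_init");
    std::printf("symbol %s\n", sym != nullptr ? "found" : "not found");
    return 0;
}
```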

View File

@ -1,5 +1,6 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#include "ggml-backend.h" #include "ggml-backend.h"
#include "ggml-backend-dl.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include <algorithm> #include <algorithm>
#include <cstring> #include <cstring>
@ -98,72 +99,6 @@ static std::string path_str(const fs::path & path) {
} }
} }
#ifdef _WIN32
using dl_handle = std::remove_pointer_t<HMODULE>;
struct dl_handle_deleter {
void operator()(HMODULE handle) {
FreeLibrary(handle);
}
};
static dl_handle * dl_load_library(const fs::path & path) {
// suppress error dialogs for missing DLLs
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
HMODULE handle = LoadLibraryW(path.wstring().c_str());
SetErrorMode(old_mode);
return handle;
}
static void * dl_get_sym(dl_handle * handle, const char * name) {
DWORD old_mode = SetErrorMode(SEM_FAILCRITICALERRORS);
SetErrorMode(old_mode | SEM_FAILCRITICALERRORS);
void * p = (void *) GetProcAddress(handle, name);
SetErrorMode(old_mode);
return p;
}
static const char * dl_error() {
return "";
}
#else
using dl_handle = void;
struct dl_handle_deleter {
void operator()(void * handle) {
dlclose(handle);
}
};
static void * dl_load_library(const fs::path & path) {
dl_handle * handle = dlopen(path.string().c_str(), RTLD_NOW | RTLD_LOCAL);
return handle;
}
static void * dl_get_sym(dl_handle * handle, const char * name) {
return dlsym(handle, name);
}
static const char * dl_error() {
const char *rslt = dlerror();
return rslt != nullptr ? rslt : "";
}
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
struct ggml_backend_reg_entry { struct ggml_backend_reg_entry {
ggml_backend_reg_t reg; ggml_backend_reg_t reg;
dl_handle_ptr handle; dl_handle_ptr handle;

View File

@ -258,6 +258,7 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
if (backend->iface.set_tensor_async == NULL) { if (backend->iface.set_tensor_async == NULL) {
ggml_backend_synchronize(backend);
ggml_backend_tensor_set(tensor, data, offset, size); ggml_backend_tensor_set(tensor, data, offset, size);
} else { } else {
backend->iface.set_tensor_async(backend, tensor, data, offset, size); backend->iface.set_tensor_async(backend, tensor, data, offset, size);
@ -271,6 +272,7 @@ void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_ten
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
if (backend->iface.get_tensor_async == NULL) { if (backend->iface.get_tensor_async == NULL) {
ggml_backend_synchronize(backend);
ggml_backend_tensor_get(tensor, data, offset, size); ggml_backend_tensor_get(tensor, data, offset, size);
} else { } else {
backend->iface.get_tensor_async(backend, tensor, data, offset, size); backend->iface.get_tensor_async(backend, tensor, data, offset, size);

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -1,5 +1,5 @@
/** /**
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -1,5 +1,5 @@
/* /*
* Copyright (c) 2023-2024 The ggml authors * Copyright (c) 2023-2026 The ggml authors
* *
* Permission is hereby granted, free of charge, to any person obtaining a copy * Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to * of this software and associated documentation files (the "Software"), to

View File

@ -268,9 +268,9 @@ static inline __m256 quad_fp16_delta_float(const float x0, const float y0, const
_mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0))); _mm_set1_ps(GGML_CPU_FP16_TO_FP32(x0) * GGML_CPU_FP16_TO_FP32(y0)));
} }
static inline __m256 quad_mx_delta_float(const int8_t x0, const float y0, const int8_t x1, const float y1) { static inline __m256 quad_mx_delta_float(const uint8_t x0, const float y0, const uint8_t x1, const float y1) {
return _mm256_set_m128(_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)), return _mm256_set_m128(_mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x1) * GGML_CPU_FP16_TO_FP32(y1)),
_mm_set1_ps(GGML_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0))); _mm_set1_ps(GGML_CPU_E8M0_TO_FP32_HALF(x0) * GGML_CPU_FP16_TO_FP32(y0)));
} }
#endif #endif
#elif defined(__SSSE3__) #elif defined(__SSSE3__)
@ -782,6 +782,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
__m256 accum1 = _mm256_setzero_ps(); __m256 accum1 = _mm256_setzero_ps();
__m256 accum2 = _mm256_setzero_ps(); __m256 accum2 = _mm256_setzero_ps();
for (; ib + 1 < nb; ib += 2) { for (; ib + 1 < nb; ib += 2) {
const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs); const __m128i q4bits_1 = _mm_loadu_si128((const __m128i*)x[ib + 0].qs);
const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs); const __m128i q4bits_2 = _mm_loadu_si128((const __m128i*)x[ib + 1].qs);
@ -795,10 +796,10 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2); const __m256i p16_2 = mul_add_epi8(q4b_2, q8b_2);
const __m256i p_1 = _mm256_madd_epi16(p16_1, mone); const __m256i p_1 = _mm256_madd_epi16(p16_1, mone);
const __m256i p_2 = _mm256_madd_epi16(p16_2, mone); const __m256i p_2 = _mm256_madd_epi16(p16_2, mone);
accum1 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 0].e)), const __m256 scale0 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 0].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 0].e));
_mm256_cvtepi32_ps(p_1), accum1); const __m256 scale1 = _mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib + 1].e));
accum2 = _mm256_fmadd_ps(_mm256_set1_ps(GGML_CPU_FP16_TO_FP32(y[ib + 1].d)*GGML_E8M0_TO_FP32_HALF(x[ib + 1].e)), accum1 = _mm256_fmadd_ps(scale0, _mm256_cvtepi32_ps(p_1), accum1);
_mm256_cvtepi32_ps(p_2), accum2); accum2 = _mm256_fmadd_ps(scale1, _mm256_cvtepi32_ps(p_2), accum2);
} }
sumf = hsum_float_8(_mm256_add_ps(accum1, accum2)); sumf = hsum_float_8(_mm256_add_ps(accum1, accum2));
@ -830,7 +831,7 @@ void ggml_vec_dot_mxfp4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const vo
#endif #endif
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_E8M0_TO_FP32_HALF(x[ib].e); const float d = GGML_CPU_FP16_TO_FP32(y[ib].d)*GGML_CPU_E8M0_TO_FP32_HALF(x[ib].e);
int sumi1 = 0; int sumi1 = 0;
int sumi2 = 0; int sumi2 = 0;
for (int j = 0; j < QK_MXFP4/2; ++j) { for (int j = 0; j < QK_MXFP4/2; ++j) {
@ -3817,4 +3818,3 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc); ggml_vec_dot_iq4_xs_q8_K_generic(n, s, bs, vx, bx, vy, by, nrc);
#endif #endif
} }

View File

@ -24,6 +24,9 @@ struct ggml_compute_params {
void * wdata; void * wdata;
struct ggml_threadpool * threadpool; struct ggml_threadpool * threadpool;
// use reference implementation
bool use_ref;
}; };

View File

@ -5,7 +5,6 @@
#include "ggml-backend.h" #include "ggml-backend.h"
#include "traits.h" #include "traits.h"
#include "ggml-cpu-impl.h" #include "ggml-cpu-impl.h"
#include "ggml-cpu.h"
#include "ggml-impl.h" #include "ggml-impl.h"
#include "quants.h" #include "quants.h"
#include "ggml-threading.h" #include "ggml-threading.h"
@ -76,6 +75,9 @@
// precomputed f32 table for f16 (256 KB) (simd-mappings.h) // precomputed f32 table for f16 (256 KB) (simd-mappings.h)
float ggml_table_f32_f16[1 << 16]; float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB) (simd-mappings.h)
float ggml_table_f32_e8m0_half[1 << 8];
#if defined(__ARM_ARCH) #if defined(__ARM_ARCH)
struct ggml_arm_arch_features_type { struct ggml_arm_arch_features_type {
int sve_cnt; int sve_cnt;
@ -2867,12 +2869,20 @@ struct ggml_cplan ggml_graph_plan(
} break; } break;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
{ {
const int64_t neq2 = node->src[0]->ne[2]; // number of query heads
const int64_t DK = node->src[1]->ne[0]; const int64_t DK = node->src[1]->ne[0];
const int64_t DV = node->src[2]->ne[0]; const int64_t DV = node->src[2]->ne[0];
// Tiled flash attention scratch (tile sizes defined in common.h) // Tiled flash attention scratch (tile sizes defined in common.h)
// Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding // Per-thread: Q_q + KQ + mask + VKQ32 + V32 + padding
cur = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks; size_t prefill = sizeof(float)*(GGML_FA_TILE_Q*DK + 2*GGML_FA_TILE_Q*GGML_FA_TILE_KV + GGML_FA_TILE_Q*DV + GGML_FA_TILE_KV*DV)*n_tasks;
// Decode path: n_kv_chunks = n_tasks (one chunk per thread)
// Per-thread: VKQ accumulator (DV), partial M, partial S + intra-thread scratch for V, Q and VKQ
size_t n_chunks = n_tasks;
size_t decode = sizeof(float)*(neq2*n_chunks*(2+DV) + n_tasks*(DK + 2*DV));
cur += MAX(prefill, decode);
} break; } break;
case GGML_OP_FLASH_ATTN_BACK: case GGML_OP_FLASH_ATTN_BACK:
{ {
@ -2929,11 +2939,12 @@ static thread_ret_t ggml_graph_compute_thread(void * data) {
set_numa_thread_affinity(state->ith); set_numa_thread_affinity(state->ith);
struct ggml_compute_params params = { struct ggml_compute_params params = {
/*.ith =*/ state->ith, /*.ith =*/ state->ith,
/*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK, /*.nth =*/ atomic_load_explicit(&tp->n_graph, memory_order_relaxed) & GGML_THREADPOOL_N_THREADS_MASK,
/*.wsize =*/ cplan->work_size, /*.wsize =*/ cplan->work_size,
/*.wdata =*/ cplan->work_data, /*.wdata =*/ cplan->work_data,
/*.threadpool=*/ tp, /*.threadpool =*/ tp,
/*.use_ref =*/ cplan->use_ref,
}; };
GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph); GGML_PRINT_DEBUG("thread #%d compute-start cplan %p last-graph %d \n", state->ith, cplan, state->last_graph);
@ -3673,6 +3684,11 @@ void ggml_cpu_init(void) {
ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f)); ggml_table_gelu_quick_f16[i] = GGML_CPU_FP32_TO_FP16(ggml_gelu_quick_f32(f));
} }
// initialize E8M0 half table (256 entries)
for (int i = 0; i < (1 << 8); ++i) {
ggml_table_f32_e8m0_half[i] = GGML_E8M0_TO_FP32_HALF(i);
}
const uint64_t t_end = ggml_time_us(); UNUSED(t_end); const uint64_t t_end = ggml_time_us(); UNUSED(t_end);
GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0); GGML_PRINT_DEBUG("%s: GELU, Quick GELU, SILU and EXP tables initialized in %f ms\n", __func__, (t_end - t_start)/1000.0);

View File

@ -105,6 +105,8 @@ struct ggml_backend_cpu_context {
ggml_abort_callback abort_callback; ggml_abort_callback abort_callback;
void * abort_callback_data; void * abort_callback_data;
bool use_ref; // use reference implementation
}; };
static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) { static const char * ggml_backend_cpu_get_name(ggml_backend_t backend) {
@ -143,6 +145,7 @@ static ggml_backend_graph_plan_t ggml_backend_cpu_graph_plan_create(ggml_backend
cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback; cpu_plan->cplan.abort_callback = cpu_ctx->abort_callback;
cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data; cpu_plan->cplan.abort_callback_data = cpu_ctx->abort_callback_data;
cpu_plan->cplan.use_ref = cpu_ctx->use_ref;
return cpu_plan; return cpu_plan;
} }
@ -182,6 +185,7 @@ static enum ggml_status ggml_backend_cpu_graph_compute(ggml_backend_t backend, s
cplan.abort_callback = cpu_ctx->abort_callback; cplan.abort_callback = cpu_ctx->abort_callback;
cplan.abort_callback_data = cpu_ctx->abort_callback_data; cplan.abort_callback_data = cpu_ctx->abort_callback_data;
cplan.use_ref = cpu_ctx->use_ref;
return ggml_graph_compute(cgraph, &cplan); return ggml_graph_compute(cgraph, &cplan);
} }
@ -223,6 +227,7 @@ ggml_backend_t ggml_backend_cpu_init(void) {
ctx->work_size = 0; ctx->work_size = 0;
ctx->abort_callback = NULL; ctx->abort_callback = NULL;
ctx->abort_callback_data = NULL; ctx->abort_callback_data = NULL;
ctx->use_ref = false;
ggml_backend_t cpu_backend = new ggml_backend { ggml_backend_t cpu_backend = new ggml_backend {
/* .guid = */ ggml_backend_cpu_guid(), /* .guid = */ ggml_backend_cpu_guid(),
@ -270,6 +275,13 @@ void ggml_backend_cpu_set_abort_callback(ggml_backend_t backend_cpu, ggml_abort_
ctx->abort_callback_data = abort_callback_data; ctx->abort_callback_data = abort_callback_data;
} }
void ggml_backend_cpu_set_use_ref(ggml_backend_t backend_cpu, bool use_ref) {
GGML_ASSERT(ggml_backend_is_cpu(backend_cpu));
struct ggml_backend_cpu_context * ctx = (struct ggml_backend_cpu_context *)backend_cpu->context;
ctx->use_ref = use_ref;
}
// CPU backend - device // CPU backend - device
struct ggml_backend_cpu_device_context { struct ggml_backend_cpu_device_context {
@ -646,6 +658,9 @@ static void * ggml_backend_cpu_get_proc_address(ggml_backend_reg_t reg, const ch
if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) { if (strcmp(name, "ggml_backend_cpu_is_numa") == 0) {
return (void *)ggml_is_numa; return (void *)ggml_is_numa;
} }
if (strcmp(name, "ggml_backend_cpu_set_use_ref") == 0) {
return (void *)ggml_backend_cpu_set_use_ref;
}
// threadpool - TODO: move to ggml-base // threadpool - TODO: move to ggml-base
if (strcmp(name, "ggml_threadpool_new") == 0) { if (strcmp(name, "ggml_threadpool_new") == 0) {

View File

@ -8042,12 +8042,14 @@ void ggml_compute_forward_top_k(
} }
} }
// ggml_compute_forward_flash_attn_ext
static void ggml_compute_forward_flash_attn_ext_f16_one_chunk( static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
const ggml_compute_params * params, const ggml_compute_params * params,
ggml_tensor * dst, ggml_tensor * dst,
int ir0, int ir1) { int ir0, int ir1,
int64_t ic_start, int64_t ic_end,
float * partials, int64_t partial_stride) {
const bool write_partials = (partials != nullptr);
const ggml_tensor * q = dst->src[0]; const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1]; const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2]; const ggml_tensor * v = dst->src[2];
@ -8124,7 +8126,6 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
int ith = params->ith; int ith = params->ith;
// loop over n_batch and n_head
for (int ir = ir0; ir < ir1; ++ir) { for (int ir = ir0; ir < ir1; ++ir) {
// q indices // q indices
const int iq3 = ir/(neq2*neq1); const int iq3 = ir/(neq2*neq1);
@ -8165,7 +8166,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
// loop over n_kv and n_head_kv // loop over n_kv and n_head_kv
// ref: https://arxiv.org/pdf/2112.05682.pdf // ref: https://arxiv.org/pdf/2112.05682.pdf
for (int64_t ic = 0; ic < nek1; ++ic) { for (int64_t ic = ic_start; ic < ic_end; ++ic) {
const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f; const float mv = mp ? slope*GGML_CPU_FP16_TO_FP32(mp[ic]) : 0.0f;
if (mv == -INFINITY) { if (mv == -INFINITY) {
continue; continue;
@ -8238,8 +8239,8 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
} }
} }
// sinks // sinks - apply only on the first kv-chunk
if (sinks) { if (sinks && ic_start == 0) {
const float s = ((float *)((char *) sinks->data))[h]; const float s = ((float *)((char *) sinks->data))[h];
float ms = 1.0f; float ms = 1.0f;
@ -8247,6 +8248,7 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
if (s > M) { if (s > M) {
ms = expf(M - s); ms = expf(M - s);
M = s;
ggml_vec_scale_f32(DV, VKQ32, ms); ggml_vec_scale_f32(DV, VKQ32, ms);
} else { } else {
vs = expf(s - M); vs = expf(s - M);
@ -8255,20 +8257,26 @@ static void ggml_compute_forward_flash_attn_ext_f16_one_chunk(
S = S*ms + vs; S = S*ms + vs;
} }
// V /= S if (write_partials) {
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; // Write M, S, VKQ to partials for later reduction
ggml_vec_scale_f32(DV, VKQ32, S_inv); // partials layout: [M, S, VKQ[DV]] per query head
float * partial = partials + ir * partial_stride;
partial[0] = M;
partial[1] = S;
memcpy(partial + 2, VKQ32, DV * sizeof(float));
} else {
// V /= S
const float S_inv = S == 0.0f ? 0.0f : 1.0f/S;
ggml_vec_scale_f32(DV, VKQ32, S_inv);
// dst indices // dst indices
const int i1 = iq1; const int i1 = iq1;
const int i2 = iq2; const int i2 = iq2;
const int i3 = iq3; const int i3 = iq3;
// original // permute(0, 2, 1, 3)
//memcpy((char *) dst->data + (i1*nb1 + i2*nb2 + i3*nb3), V, nev0*sizeof(float)); memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
}
// permute(0, 2, 1, 3)
memcpy((char *) dst->data + (i3*ne2*ne1 + i2 + i1*ne1)*nb1, VKQ32, nb1);
} }
} }
@ -8546,6 +8554,78 @@ static void ggml_compute_forward_flash_attn_ext_tiled(
} }
} }
// Reduction function: combines partial results across KV chunks
// Partials layout in wdata: [n_q_heads][n_chunks][2 + DV]
static void ggml_flash_attn_ext_reduce_partials(
const ggml_compute_params * params,
ggml_tensor * dst,
const int64_t n_chunks,
const int64_t chunk_size) {
const ggml_tensor * q = dst->src[0];
const ggml_tensor * k = dst->src[1];
const ggml_tensor * v = dst->src[2];
const int64_t DK = k->ne[0];
const int64_t DV = v->ne[0];
const int64_t nek1 = k->ne[1];
const int64_t n_q_heads = q->ne[2];
const int ith = params->ith;
const int nth = params->nth;
const int64_t wdata_per_thread = DK + 2*DV + CACHE_LINE_SIZE_F32;
float * thread_wdata = (float *) params->wdata + ith * wdata_per_thread;
const int64_t partials_offset = nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
const int64_t partial_size = 2 + DV;
const float * partials_base = (const float *) params->wdata + partials_offset;
// Output layout
const int64_t ne1 = dst->ne[1];
const int64_t ne2 = dst->ne[2];
const size_t nb1 = dst->nb[1];
// Each thread reduces a subset of query heads
for (int64_t q_head = ith; q_head < n_q_heads; q_head += nth) {
float M_final = -INFINITY;
float S_final = 0.0f;
float * VKQ_final = thread_wdata;
memset(VKQ_final, 0, DV * sizeof(float));
// Combine partials from all chunks
for (int64_t chunk_idx = 0; chunk_idx < n_chunks; ++chunk_idx) {
const int64_t ic_start = chunk_idx * chunk_size;
if (ic_start >= nek1) continue;
const float * partial = partials_base + (q_head * n_chunks + chunk_idx) * partial_size;
const float M_chunk = partial[0];
const float S_chunk = partial[1];
const float * VKQ_chunk = partial + 2;
if (S_chunk == 0.0f) continue;
const float M_new = fmaxf(M_final, M_chunk);
const float scale_old = expf(M_final - M_new);
const float scale_new = expf(M_chunk - M_new);
for (int64_t d = 0; d < DV; ++d) {
VKQ_final[d] = VKQ_final[d] * scale_old + VKQ_chunk[d] * scale_new;
}
S_final = S_final * scale_old + S_chunk * scale_new;
M_final = M_new;
}
// Normalize and write to output
if (S_final != 0.0f) {
const float S_inv = 1.0f / S_final;
ggml_vec_scale_f32(DV, VKQ_final, S_inv);
}
// iq1=0, iq3=0 for decode
memcpy((char *) dst->data + (0*ne2*ne1 + q_head + 0*ne1)*nb1, VKQ_final, nb1);
}
}
static void ggml_compute_forward_flash_attn_ext_f16( static void ggml_compute_forward_flash_attn_ext_f16(
const ggml_compute_params * params, const ggml_compute_params * params,
ggml_tensor * dst) { ggml_tensor * dst) {
@ -8567,6 +8647,7 @@ static void ggml_compute_forward_flash_attn_ext_f16(
const int64_t DV = nev0; const int64_t DV = nev0;
const int64_t N = neq1; const int64_t N = neq1;
GGML_ASSERT(ne0 == DV); GGML_ASSERT(ne0 == DV);
GGML_ASSERT(ne2 == N); GGML_ASSERT(ne2 == N);
@ -8587,60 +8668,92 @@ static void ggml_compute_forward_flash_attn_ext_f16(
GGML_ASSERT(nb1 <= nb2); GGML_ASSERT(nb1 <= nb2);
GGML_ASSERT(nb2 <= nb3); GGML_ASSERT(nb2 <= nb3);
// parallelize by q rows using ggml_vec_dot_f32
// total rows in q
const int64_t nr = neq1*neq2*neq3;
// rows per thread
const int ith = params->ith; const int ith = params->ith;
const int nth = params->nth; const int nth = params->nth;
// disable for NUMA // When use_ref is set, force the vec-only reference implementation (no tiling, no KV-chunking)
const bool disable_chunking = ggml_is_numa(); const bool use_ref = params->use_ref;
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
// Every thread starts at ith, so the first unprocessed chunk is nth. This save a bit of coordination right at the start.
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
// The number of elements in each chunk
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16); const bool kv_is_f32_or_f16 = (k->type == GGML_TYPE_F32 || k->type == GGML_TYPE_F16);
const bool use_tiled = (q->type == GGML_TYPE_F32 && const bool use_split_kv_path = !use_ref && (neq1 == 1 && neq3 == 1) && kv_is_f32_or_f16 && (k->type == v->type) && q->type == GGML_TYPE_F32 && nek1 >= 512;
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ); // Only use tiled for batch >= tile size
// The first chunk comes from our thread_id, the rest will get auto-assigned. if (use_split_kv_path) {
int current_chunk = ith; const int64_t chunk_size = (nek1 + nth - 1) / nth;
while (current_chunk < nchunk) { // Partials buffer layout: [q_head][kv_chunk][M, S, VKQ]
const int64_t ir0 = dr * current_chunk; const int64_t partial_size = 2 + DV;
const int64_t ir1 = MIN(ir0 + dr, nr); float * partials_base = (float *) params->wdata + nth * (DK + 2*DV + CACHE_LINE_SIZE_F32);
if (use_tiled) { const int64_t ic_start = ith * chunk_size;
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1); const int64_t ic_end = std::min(ic_start + chunk_size, nek1);
const int64_t partial_stride = nth * partial_size;
float * chunk_partials = partials_base + ith * partial_size;
if (ic_start < nek1) {
for (int64_t q_head = 0; q_head < neq2; q_head++) {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(
params, dst, q_head, q_head + 1, ic_start, ic_end,
chunk_partials, partial_stride);
}
} else { } else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1); for (int64_t q_head = 0; q_head < neq2; q_head++) {
float * q_partials = chunk_partials + q_head * partial_stride;
q_partials[0] = -INFINITY; // M
q_partials[1] = 0.0f; // S
}
} }
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1); ggml_barrier(params->threadpool);
ggml_flash_attn_ext_reduce_partials(params, dst, nth, chunk_size);
} else {
// total rows in q
const int64_t nr = neq1*neq2*neq3;
// disable for NUMA
const bool disable_chunking = ggml_is_numa();
// 4x chunks per thread
int nth_scaled = nth * 4;
int64_t chunk_size = (nr + nth_scaled - 1) / nth_scaled;
int64_t nchunk = (nr + chunk_size - 1) / chunk_size;
if (nth == 1 || nchunk < nth || disable_chunking) {
nchunk = nth;
}
if (ith == 0) {
ggml_threadpool_chunk_set(params->threadpool, nth);
}
ggml_barrier(params->threadpool);
const int64_t dr = (nr + nchunk - 1) / nchunk;
static constexpr int64_t KV_TILE_SZ = ggml_fa_tile_config::KV;
static constexpr int64_t Q_TILE_SZ = ggml_fa_tile_config::Q;
const bool use_tiled = !use_ref &&
(q->type == GGML_TYPE_F32 &&
kv_is_f32_or_f16 &&
k->type == v->type &&
nek1 % KV_TILE_SZ == 0 &&
neq1 >= Q_TILE_SZ);
int current_chunk = ith;
while (current_chunk < nchunk) {
const int64_t ir0 = dr * current_chunk;
const int64_t ir1 = MIN(ir0 + dr, nr);
if (use_tiled) {
ggml_compute_forward_flash_attn_ext_tiled(params, dst, ir0, ir1);
} else {
ggml_compute_forward_flash_attn_ext_f16_one_chunk(params, dst, ir0, ir1, 0, nek1, nullptr, 0);
}
current_chunk = ggml_threadpool_chunk_add(params->threadpool, 1);
}
} }
} }

View File

@ -116,6 +116,17 @@ extern "C" {
// defined in ggml-cpu.c, initialized in ggml_cpu_init() // defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_f16[1 << 16]; extern float ggml_table_f32_f16[1 << 16];
// precomputed f32 table for e8m0 half (1 KB)
// defined in ggml-cpu.c, initialized in ggml_cpu_init()
extern float ggml_table_f32_e8m0_half[1 << 8];
// Use lookup table for E8M0 on x86 (faster than bit manipulation)
#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__)
#define GGML_CPU_E8M0_TO_FP32_HALF(x) ggml_table_f32_e8m0_half[(uint8_t)(x)]
#else
#define GGML_CPU_E8M0_TO_FP32_HALF(x) GGML_E8M0_TO_FP32_HALF(x)
#endif
// On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32, // On ARM NEON, it's quicker to directly convert x -> x instead of calling into ggml_lookup_fp16_to_fp32,
// so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON. // so we define GGML_CPU_FP16_TO_FP32 and GGML_CPU_FP32_TO_FP16 elsewhere for NEON.
// This is also true for POWER9. // This is also true for POWER9.


@ -1122,15 +1122,18 @@ struct ggml_tensor_extra_gpu {
#endif #endif
struct ggml_cuda_graph_node_properties { struct ggml_cuda_graph_node_properties {
void * node_address; void * node_data;
ggml_op node_op; ggml_op node_op;
enum ggml_type node_type;
int32_t flags; int32_t flags;
int64_t ne[GGML_MAX_DIMS]; int64_t ne[GGML_MAX_DIMS];
size_t nb[GGML_MAX_DIMS]; size_t nb[GGML_MAX_DIMS];
void * src_address[GGML_MAX_SRC]; void * src_data[GGML_MAX_SRC];
int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)]; int32_t op_params[GGML_MAX_OP_PARAMS / sizeof(int32_t)];
}; };
static_assert(std::is_trivial<ggml_cuda_graph_node_properties>::value, "ggml_cuda_graph_node_properties must be trivial");
struct ggml_cuda_graph { struct ggml_cuda_graph {
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
~ggml_cuda_graph() { ~ggml_cuda_graph() {
@ -1150,6 +1153,12 @@ struct ggml_cuda_graph {
int number_consecutive_updates = 0; int number_consecutive_updates = 0;
std::vector<ggml_cuda_graph_node_properties> props; std::vector<ggml_cuda_graph_node_properties> props;
// these are extra tensors (inputs) that participate in the ggml graph but are not nodes
// their properties also have to match in order to be able to safely reuse a CUDA graph
// ref: https://github.com/ggml-org/llama.cpp/pull/18583
// ref: https://github.com/ggml-org/llama.cpp/pull/19165
std::vector<ggml_cuda_graph_node_properties> extra;
void record_update(bool use_graph, bool update_required) { void record_update(bool use_graph, bool update_required) {
if (use_graph && update_required) { if (use_graph && update_required) {
number_consecutive_updates++; number_consecutive_updates++;


@ -310,8 +310,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
} }
} }
const bool V_is_K_view = V->view_src && (V->view_src == K || (V->view_src == K->view_src && V->view_offs == K->view_offs));
const int cc = ggml_cuda_info().devices[device].cc; const int cc = ggml_cuda_info().devices[device].cc;
switch (K->ne[0]) { switch (K->ne[0]) {
@ -334,9 +332,6 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
if (!gqa_opt_applies) { if (!gqa_opt_applies) {
return BEST_FATTN_KERNEL_NONE; return BEST_FATTN_KERNEL_NONE;
} }
if (!V_is_K_view) {
return BEST_FATTN_KERNEL_NONE;
}
break; break;
default: default:
return BEST_FATTN_KERNEL_NONE; return BEST_FATTN_KERNEL_NONE;


@ -70,17 +70,18 @@
#include <condition_variable> #include <condition_variable>
#include <cstddef> #include <cstddef>
#include <cstdint> #include <cstdint>
#include <float.h> #include <cfloat>
#include <initializer_list> #include <initializer_list>
#include <limits> #include <limits>
#include <map> #include <map>
#include <memory> #include <memory>
#include <mutex> #include <mutex>
#include <stdarg.h> #include <cstdarg>
#include <stdio.h> #include <cstdio>
#include <stdlib.h> #include <cstdlib>
#include <string> #include <string>
#include <vector> #include <vector>
#include <unordered_set>
static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size"); static_assert(sizeof(half) == sizeof(ggml_fp16_t), "wrong fp16 size");
@ -2278,13 +2279,19 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc;
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
if (ne2 == 1) { static_assert(MMVQ_MAX_BATCH_SIZE == MMVF_MAX_BATCH_SIZE);
if (ne2 <= MMVQ_MAX_BATCH_SIZE) {
if (ggml_is_quantized(src0->type)) { if (ggml_is_quantized(src0->type)) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst); if (ne2 <= 4) {
ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
return;
}
} else { } else {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst); if (GGML_CUDA_CC_IS_AMD(cc)) {
ggml_cuda_mul_mat_vec_f(ctx, src0, src1, ids, dst);
return;
}
} }
return;
} }
if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) { if (ggml_cuda_should_use_mmq(src0->type, cc, ne12, /*n_experts=*/ne02)) {
@ -2916,22 +2923,27 @@ static bool ggml_cuda_graph_check_compability(ggml_cgraph * cgraph) {
} }
static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) { static void ggml_cuda_graph_node_set_properties(ggml_cuda_graph_node_properties * props, ggml_tensor * node) {
props->node_address = node->data; memset(props, 0, sizeof(ggml_cuda_graph_node_properties));
props->node_data = node->data;
props->node_op = node->op; props->node_op = node->op;
props->node_type = node->type;
props->flags = node->flags; props->flags = node->flags;
for (int i = 0; i < GGML_MAX_DIMS; i++) { for (int i = 0; i < GGML_MAX_DIMS; i++) {
props->ne[i] = node->ne[i]; props->ne[i] = node->ne[i];
props->nb[i] = node->nb[i]; props->nb[i] = node->nb[i];
} }
for (int i = 0; i < GGML_MAX_SRC; i++) { for (int i = 0; i < GGML_MAX_SRC; i++) {
props->src_address[i] = node->src[i] ? node->src[i]->data : nullptr; if (!node->src[i]) {
continue;
}
props->src_data[i] = node->src[i]->data;
} }
memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS); memcpy(props->op_params, node->op_params, GGML_MAX_OP_PARAMS);
} }
static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) { static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_graph_node_properties * props) {
if (node->data != props->node_address && if (node->data != props->node_data && node->op != GGML_OP_VIEW) {
node->op != GGML_OP_VIEW) {
return false; return false;
} }
@ -2939,6 +2951,10 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
return false; return false;
} }
if (node->type != props->node_type) {
return false;
}
for (int i = 0; i < GGML_MAX_DIMS; i++) { for (int i = 0; i < GGML_MAX_DIMS; i++) {
if (node->ne[i] != props->ne[i]) { if (node->ne[i] != props->ne[i]) {
return false; return false;
@ -2948,12 +2964,18 @@ static bool ggml_cuda_graph_node_properties_match(ggml_tensor * node, ggml_cuda_
} }
} }
for (int i = 0; i < GGML_MAX_SRC; i++) { if (node->op != GGML_OP_VIEW) {
if (node->src[i] && for (int i = 0; i < GGML_MAX_SRC; i++) {
node->src[i]->data != props->src_address[i] && if (!node->src[i]) {
node->op != GGML_OP_VIEW if (props->src_data[i] != nullptr) {
) { return false;
return false; }
continue;
}
if (node->src[i]->data != props->src_data[i]) {
return false;
}
} }
} }
@ -2974,7 +2996,6 @@ static const void * ggml_cuda_graph_get_key(ggml_cgraph * cgraph) {
} }
static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) { static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph) {
bool res = false; bool res = false;
const void * graph_key = ggml_cuda_graph_get_key(cgraph); const void * graph_key = ggml_cuda_graph_get_key(cgraph);
@ -2985,15 +3006,20 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
} }
// Check if the graph size has changed // Check if the graph size has changed
if (graph->props.size() != (size_t)cgraph->n_nodes + cgraph->n_leafs) { if (graph->props.size() != (size_t)cgraph->n_nodes) {
res = true; res = true;
graph->props.resize(cgraph->n_nodes + cgraph->n_leafs); graph->props.resize(cgraph->n_nodes);
} }
// Loop over nodes in GGML graph to determine if CUDA graph update is required // Loop over nodes in GGML graph to determine if CUDA graph update is required
// and store properties to allow this comparison for the next token // and store properties to allow this comparison for the next token
std::unordered_set<ggml_tensor *> seen_node;
std::vector<ggml_tensor *> srcs_extra;
for (int i = 0; i < cgraph->n_nodes; i++) { for (int i = 0; i < cgraph->n_nodes; i++) {
bool props_match = true; bool props_match = true;
seen_node.insert(cgraph->nodes[i]);
if (!res) { if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]); props_match = ggml_cuda_graph_node_properties_match(cgraph->nodes[i], &graph->props[i]);
} }
@ -3001,17 +3027,31 @@ static bool ggml_cuda_graph_update_required(ggml_backend_cuda_context * cuda_ctx
res = true; res = true;
} }
ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]); ggml_cuda_graph_node_set_properties(&graph->props[i], cgraph->nodes[i]);
for (int src_idx = 0; src_idx < GGML_MAX_SRC; ++src_idx) {
ggml_tensor * src = cgraph->nodes[i]->src[src_idx];
if (src && seen_node.find(src) == seen_node.end()) {
srcs_extra.push_back(src);
}
}
} }
for (int i = 0; i < cgraph->n_leafs; i++) { if (graph->extra.size() != (size_t) srcs_extra.size()) {
res = true;
graph->extra.resize(srcs_extra.size());
}
for (size_t i = 0; i < srcs_extra.size(); ++i) {
bool props_match = true; bool props_match = true;
if (!res) { if (!res) {
props_match = ggml_cuda_graph_node_properties_match(cgraph->leafs[i], &graph->props[cgraph->n_nodes + i]); props_match = ggml_cuda_graph_node_properties_match(srcs_extra[i], &graph->extra[i]);
} }
if (!props_match) { if (!props_match) {
res = true; res = true;
} }
ggml_cuda_graph_node_set_properties(&graph->props[cgraph->n_nodes + i], cgraph->leafs[i]); ggml_cuda_graph_node_set_properties(&graph->extra[i], srcs_extra[i]);
} }
return res; return res;
@ -3080,63 +3120,166 @@ static bool ggml_cuda_should_fuse_rope_set_rows(const ggml_tensor * rope,
return true; return true;
} }
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list<enum ggml_op> ops, std::initializer_list<enum ggml_unary_op> unary_ops) { static bool ggml_cuda_topk_moe_fusion(const struct ggml_cgraph * cgraph, int node_idx, ggml_cuda_topk_moe_args & args) {
args.sigmoid = false;
args.softmax = false;
args.delayed_softmax = false;
args.prob_bias = false;
args.norm = false;
const int n_nodes = cgraph->n_nodes;
ggml_tensor ** nodes = cgraph->nodes;
if (nodes[node_idx]->op == GGML_OP_SOFT_MAX) {
args.softmax = true;
}
if (nodes[node_idx]->op == GGML_OP_UNARY) {
if (ggml_get_unary_op(nodes[node_idx]) != GGML_UNARY_OP_SIGMOID) {
return false;
}
args.sigmoid = true;
}
if (nodes[node_idx]->op == GGML_OP_ARGSORT) {
args.delayed_softmax = true;
}
node_idx++;
if (args.sigmoid || args.softmax) {
// SOFTMAX -> RESHAPE
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_RESHAPE ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx];
node_idx++;
if (node_idx >= n_nodes) {
return false;
}
// src of bias add is the unreshaped probs (-2 instead of -1)
if (nodes[node_idx]->op == GGML_OP_ADD && nodes[node_idx]->src[0] == nodes[node_idx - 2]) {
args.prob_bias = true;
node_idx++;
}
// RESHAPE/ADD -> ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_ARGSORT) {
return false;
}
if (args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
} else if (!args.prob_bias && nodes[node_idx]->src[0] != nodes[node_idx - 2]) {
return false;
}
node_idx++;
// ARGSORT-> VIEW
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_GET_ROWS) {
return false;
}
// GET_ROWS
if (nodes[node_idx]->src[0] != probs_reshaped || nodes[node_idx]->src[1] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
} else if (args.delayed_softmax) {
if (node_idx - 2 < 0) {
return false;
}
ggml_tensor * probs_reshaped = nodes[node_idx - 2];
// VIEW->ARGSORT
if (node_idx >= n_nodes || nodes[node_idx]->op != GGML_OP_VIEW ||
nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
// GET_ROWS
if (node_idx >= n_nodes || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != probs_reshaped) {
return false;
}
node_idx++;
static const std::vector<ggml_op> remaining_ops = { GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
for (const ggml_op op : remaining_ops) {
if (node_idx >= n_nodes || nodes[node_idx]->op != op || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
return false;
}
node_idx++;
}
}
// At this point we can check for norm + scale. Everything is valid at least up to the norm
if (node_idx >= n_nodes) {
return true;
}
if (nodes[node_idx]->op == GGML_OP_RESHAPE) {
// check RESHAPE->SUM_ROWS->CLAMP->DIV->RESHAPE
static const std::vector<ggml_op> norm_ops = { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP };
args.norm = true;
for (const ggml_op op : norm_ops) {
if (nodes[node_idx]->op == op && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
node_idx++;
} else {
args.norm = false;
return true;
}
}
// DIV <- CLAMP, RESHAPE
if (nodes[node_idx]->op != GGML_OP_DIV || nodes[node_idx]->src[1] != nodes[node_idx - 1] ||
nodes[node_idx]->src[0] != nodes[node_idx - 3]) {
args.norm = false;
return true;
}
node_idx++;
if (nodes[node_idx]->op != GGML_OP_RESHAPE || nodes[node_idx]->src[0] != nodes[node_idx - 1]) {
args.norm = false;
return true;
}
node_idx++;
}
if (nodes[node_idx]->op == GGML_OP_SCALE && nodes[node_idx]->src[0] == nodes[node_idx - 1]) {
args.scale = true;
}
return true;
}
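Editor's note: written out flat, the op chain the matcher above accepts in the common softmax-gated, bias-free, normalized case looks as follows. This list is derived from the checks above for illustration only; it is not code from the patch.

#include "ggml.h"

// gating -> reshape -> argsort -> view -> get_rows, followed by the norm tail.
static const enum ggml_op topk_moe_with_norm_chain[] = {
    GGML_OP_SOFT_MAX, GGML_OP_RESHAPE,  GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
    GGML_OP_RESHAPE,  GGML_OP_SUM_ROWS, GGML_OP_CLAMP,   GGML_OP_DIV,  GGML_OP_RESHAPE,
    // optionally followed by GGML_OP_SCALE when args.scale is set
};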
static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph,
int node_idx,
std::initializer_list<enum ggml_op> ops,
std::initializer_list<enum ggml_unary_op> unary_ops) {
#ifndef NDEBUG #ifndef NDEBUG
const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY); const size_t num_unary = std::count(ops.begin(), ops.end(), GGML_OP_UNARY);
GGML_ASSERT(unary_ops.size() == num_unary); GGML_ASSERT(unary_ops.size() == num_unary);
#endif #endif
//TODO: remove special case once ggml_can_fuse can handle empty nodes
std::initializer_list<enum ggml_op> topk_moe_ops =
ggml_cuda_topk_moe_ops(/*with_norm*/ false, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_with_norm =
ggml_cuda_topk_moe_ops(/*with_norm=*/true, /*delayed_softmax=*/false);
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1, const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) { const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end()); return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
}; };
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 4];
ggml_tensor * argsort = cgraph->nodes[node_idx + 2];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
ggml_tensor * get_rows = cgraph->nodes[node_idx + 2];
ggml_tensor * argsort = cgraph->nodes[node_idx + 0];
int n_expert = cgraph->nodes[node_idx]->src[0]->ne[0];
if (ggml_cuda_should_use_topk_moe(softmax, weights, get_rows, argsort, nullptr, n_expert)) {
return true;
}
}
std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU }; std::initializer_list<enum ggml_op> mul_mat_bias_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_MUL_MAT, GGML_OP_ADD, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU }; std::initializer_list<enum ggml_op> mul_mat_id_bias_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_MUL_MAT_ID, GGML_OP_ADD_ID, GGML_OP_GLU };
@ -3398,35 +3541,75 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// start of fusion operations // start of fusion operations
static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr);
if (!disable_fusion) { if (!disable_fusion) {
ggml_cuda_topk_moe_args args;
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ true), {})) { if (cgraph->nodes[i]->op == GGML_OP_UNARY || cgraph->nodes[i]->op == GGML_OP_SOFT_MAX ||
ggml_tensor * weights = cgraph->nodes[i + 9]; cgraph->nodes[i]->op == GGML_OP_ARGSORT) {
ggml_tensor * selected_experts = cgraph->nodes[i + 3]; const bool can_fuse = ggml_cuda_topk_moe_fusion(cgraph, i, args);
ggml_tensor * clamp = cgraph->nodes[i + 7];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ true,
/*delayed softmax*/ false, clamp);
i += 9;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, ggml_cuda_topk_moe_ops(/*with norm*/ false), {})) { std::vector<ggml_op> ops;
ggml_tensor * weights = cgraph->nodes[i + 4];
ggml_tensor * selected_experts = cgraph->nodes[i + 3];
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, selected_experts, /*with norm*/ false,
/*delayed softmax*/ false);
i += 4;
continue;
}
if (ggml_cuda_can_fuse(cgraph, i, if (can_fuse) {
ggml_cuda_topk_moe_ops(/*with norm*/ false, /*delayed softmax*/ true), {})) { const ggml_tensor * logits = node->src[0];
ggml_tensor * weights = cgraph->nodes[i + 5]; ggml_tensor * weights = nullptr;
ggml_tensor * ids = cgraph->nodes[i + 1]; ggml_tensor * ids = nullptr;
const ggml_tensor * bias = nullptr;
const ggml_tensor * clamp = nullptr;
const ggml_tensor * scale = nullptr;
ggml_cuda_op_topk_moe(*cuda_ctx, node->src[0], weights, ids, /*with norm*/ false, if (!args.delayed_softmax) {
/*delayed_softmax*/ true); ggml_op gating_op = args.sigmoid ? GGML_OP_UNARY : GGML_OP_SOFT_MAX;
i += 5; int out_nodes[2]; // nodes which can't be elided
continue;
if (args.prob_bias) {
bias = cgraph->nodes[i + 2]->src[1];
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ADD, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS });
out_nodes[0] = i + 4;
ids = cgraph->nodes[i + 4];
} else {
ops.insert(ops.end(), { gating_op, GGML_OP_RESHAPE, GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS });
out_nodes[0] = i + 3;
ids = cgraph->nodes[i + 3];
}
if (args.norm) {
ops.insert(ops.end(), { GGML_OP_RESHAPE, GGML_OP_SUM_ROWS, GGML_OP_CLAMP,
GGML_OP_DIV, GGML_OP_RESHAPE });
clamp = cgraph->nodes[i + ops.size() - 3];
}
if (args.scale) {
ops.insert(ops.end(), { GGML_OP_SCALE });
scale = cgraph->nodes[i + ops.size() - 1];
}
weights = cgraph->nodes[i + ops.size() - 1];
out_nodes[1] = i + ops.size() - 1;
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(node, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
} else if (!args.norm && !args.prob_bias) {
// special case for gpt-oss: no norm, no bias.
ops.insert(ops.end(), { GGML_OP_ARGSORT, GGML_OP_VIEW, GGML_OP_GET_ROWS,
GGML_OP_RESHAPE, GGML_OP_SOFT_MAX, GGML_OP_RESHAPE });
weights = cgraph->nodes[i + 5];
ids = cgraph->nodes[i + 1];
const ggml_tensor * softmax = cgraph->nodes[i + 4];
int out_nodes[2] = { i + 1, i + 5 };
if (ggml_can_fuse_subgraph(cgraph, i, ops.size(), ops.data(), out_nodes, 2) &&
ggml_cuda_should_use_topk_moe(softmax, logits, weights, ids)) {
ggml_cuda_op_topk_moe(*cuda_ctx, logits, weights, ids, clamp, scale, bias, args);
i += ops.size() - 1;
continue;
}
}
}
} }
if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) { if (ggml_cuda_can_fuse(cgraph, i, { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS }, {})) {
@ -3733,14 +3916,14 @@ static void ggml_cuda_graph_evaluate_and_capture(ggml_backend_cuda_context * cud
// Launch graph // Launch graph
CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream())); CUDA_CHECK(cudaGraphLaunch(graph->instance, cuda_ctx->stream()));
#else #else
GGML_UNUSED(graph_key);
graph_evaluated_or_captured = true; graph_evaluated_or_captured = true;
#endif // USE_CUDA_GRAPH #endif // USE_CUDA_GRAPH
} }
} }
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
#ifdef USE_CUDA_GRAPH #ifdef USE_CUDA_GRAPH
static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, const void * graph_key) {
ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key); ggml_cuda_graph * graph = cuda_ctx->cuda_graph(graph_key);
if (graph->graph == nullptr) { if (graph->graph == nullptr) {
@ -3753,12 +3936,8 @@ static bool ggml_cuda_graph_set_enabled(ggml_backend_cuda_context * cuda_ctx, co
} }
return graph->is_enabled(); return graph->is_enabled();
#else
GGML_UNUSED(cuda_ctx);
GGML_UNUSED(graph_key);
return false;
#endif // USE_CUDA_GRAPH
} }
#endif // USE_CUDA_GRAPH
static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context; ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *) backend->context;
@ -4876,16 +5055,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (!initialized) { if (!initialized) {
// Set CUDA_SCALE_LAUNCH_QUEUES before any CUDA API call to improve multi-GPU pipeline parallelism performance
// PR: https://github.com/ggml-org/llama.cpp/pull/19042
if (getenv("CUDA_SCALE_LAUNCH_QUEUES") == nullptr) {
#ifdef _WIN32
_putenv_s("CUDA_SCALE_LAUNCH_QUEUES", "4x");
#else
setenv("CUDA_SCALE_LAUNCH_QUEUES", "4x", 0); // don't overwrite if already set
#endif // _WIN32
}
ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context;
const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32; const int min_batch_size = getenv("GGML_OP_OFFLOAD_MIN_BATCH") ? atoi(getenv("GGML_OP_OFFLOAD_MIN_BATCH")) : 32;


@ -333,7 +333,33 @@ namespace ggml_cuda_mma {
static __device__ __forceinline__ int get_j(const int l) { static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) { if constexpr (I == 16 && J == 8) {
return 4 * (threadIdx.x / 16) + l; return ne * (threadIdx.x / 16) + l;
} else {
NO_DEVICE_CODE;
return -1;
}
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = I * J / 64;
half2 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
if (I == 16 && J == 8) return true;
return false;
}
static __device__ __forceinline__ int get_i(const int l) {
if constexpr (I == 16 && J == 8) {
return threadIdx.x % 16;
} else {
NO_DEVICE_CODE;
return -1;
}
}
static __device__ __forceinline__ int get_j(const int l) {
if constexpr (I == 16 && J == 8) {
return ne * (threadIdx.x / 16) + l;
} else { } else {
NO_DEVICE_CODE; NO_DEVICE_CODE;
return -1; return -1;
@ -391,7 +417,22 @@ namespace ggml_cuda_mma {
static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR; static constexpr data_layout dl = DATA_LAYOUT_I_MAJOR;
#if defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE)
static constexpr int ne = I * J / 32; static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::supported();
}
static __device__ __forceinline__ int get_i(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_i(l);
}
static __device__ __forceinline__ int get_j(const int l) {
return tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::get_j(l);
}
#elif defined(AMD_MFMA_AVAILABLE)
static constexpr int ne = tile<I_, J_, half2, DATA_LAYOUT_I_MAJOR>::ne;
nv_bfloat162 x[ne] = {{0.0f, 0.0f}}; nv_bfloat162 x[ne] = {{0.0f, 0.0f}};
static constexpr __device__ bool supported() { static constexpr __device__ bool supported() {
@ -945,6 +986,32 @@ namespace ggml_cuda_mma {
#endif // AMPERE_MMA_AVAILABLE #endif // AMPERE_MMA_AVAILABLE
} }
template <data_layout dl_ab, data_layout dl_d>
static __device__ __forceinline__ void mma(
tile<16, 16, float, dl_d> & D, const tile<16, 8, float, dl_ab> & A, const tile<16, 8, float, dl_ab> & B) {
#ifdef AMD_MFMA_AVAILABLE
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3)
using floatx2_t = __attribute__((ext_vector_type(2))) float;
const floatx2_t& a_frag = reinterpret_cast<const floatx2_t&>(A.x[0]);
const floatx2_t& b_frag = reinterpret_cast<const floatx2_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8_xf32(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA2) || defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
acc_frag = __builtin_amdgcn_mfma_f32_16x16x4f32(A.x[i], B.x[i], acc_frag, 0, 0, 0);
}
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(CDNA3)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // AMD_MFMA_AVAILABLE
}
static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D, static __device__ __forceinline__ void mma_block_scaled(tile<16, 8, float> & D,
const tile<16, 8, int> & A, const tile<16, 8, int> & A,
const tile<8, 8, int> & B, const tile<8, 8, int> & B,
@ -1054,6 +1121,13 @@ namespace ggml_cuda_mma {
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // RDNA4 #endif // RDNA4
#elif defined(AMD_MFMA_AVAILABLE)
using halfx4_t = __attribute__((ext_vector_type(4))) _Float16;
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
const halfx4_t& a_frag = reinterpret_cast<const halfx4_t&>(A.x[0]);
const halfx4_t& b_frag = reinterpret_cast<const halfx4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc_frag, 0, 0, 0);
#else #else
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
@ -1081,11 +1155,31 @@ namespace ggml_cuda_mma {
#else #else
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // RDNA4 #endif // defined(RDNA4)
#elif defined(AMD_MFMA_AVAILABLE)
using floatx4_t = __attribute__((ext_vector_type(4))) float;
floatx4_t& acc_frag = reinterpret_cast<floatx4_t&>(D.x[0]);
#if defined(CDNA3) || defined(CDNA2)
using bf16x4_t = __attribute__((ext_vector_type(4))) __bf16;
const bf16x4_t& a_frag = reinterpret_cast<const bf16x4_t&>(A.x[0]);
const bf16x4_t& b_frag = reinterpret_cast<const bf16x4_t&>(B.x[0]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_frag, b_frag, acc_frag, 0, 0, 0);
#elif defined(CDNA1)
#pragma unroll
for (int i = 0; i < 2; ++i) {
using bf16x2_t = __attribute__((ext_vector_type(2))) __bf16;
const bf16x2_t& a_frag = reinterpret_cast<const bf16x2_t&>(A.x[i]);
const bf16x2_t& b_frag = reinterpret_cast<const bf16x2_t&>(B.x[i]);
acc_frag = __builtin_amdgcn_mfma_f32_16x16x8bf16(a_frag, b_frag, acc_frag, 0, 0, 0);
}
#else #else
GGML_UNUSED_VARS(D, A, B); GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // AMPERE_MMA_AVAILABLE #endif // defined(CDNA3) || defined(CDNA2)
#else
GGML_UNUSED_VARS(D, A, B);
NO_DEVICE_CODE;
#endif // defined(AMD_WMMA_AVAILABLE)
} }
template <data_layout dl_d, data_layout dl_ab> template <data_layout dl_d, data_layout dl_ab>


@ -2,6 +2,13 @@
#include "mmf.cuh" #include "mmf.cuh"
#include "mmid.cuh" #include "mmid.cuh"
static __forceinline__ int mmf_get_rows_per_block(const int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return MMF_ROWS_PER_BLOCK_CDNA;
} else {
return MMF_ROWS_PER_BLOCK;
}
}
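Editor's note: one practical consequence of the helper above, sketched as a standalone predicate (illustrative, not part of the patch): together with the check in ggml_cuda_should_use_mmf below, the kernel is only eligible when the weight matrix has a multiple of 64 rows on CDNA, versus 32 elsewhere.

#include <stdint.h>

// Hypothetical helper mirroring the divisibility requirement; the two macros come from mmf.cuh.
static bool mmf_rows_ok(int64_t nrows, bool is_cdna) {
    const int rows_per_block = is_cdna ? MMF_ROWS_PER_BLOCK_CDNA : MMF_ROWS_PER_BLOCK; // 64 vs 32
    return nrows % rows_per_block == 0;
}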
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( src1->type == GGML_TYPE_F32);
@ -89,28 +96,32 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
ids_info_ptr = &ids_info; ids_info_ptr = &ids_info;
} }
const int device = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[device].cc;
const int rows_per_block = mmf_get_rows_per_block(cc);
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1; constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_rows_per_block<float>(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data; const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2; constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_rows_per_block<half2>(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2; constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_rows_per_block<nv_bfloat162>(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst, rows_per_block, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream(), ids_info_ptr);
} break; } break;
@ -140,7 +151,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false; return false;
} }
} }
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { if (src0_ne[1] % mmf_get_rows_per_block(cc) != 0) {
return false;
}
if (GGML_CUDA_CC_IS_CDNA3(cc) && type == GGML_TYPE_BF16) {
return false; return false;
} }
@ -153,6 +168,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
} else { } else {
if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) { if (GGML_CUDA_CC_IS_RDNA3_0(cc) && src1_ncols > 8) {
return false; return false;
} else if (GGML_CUDA_CC_IS_CDNA2(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
//TODO: treat CDNA2 as CDNA1, tune the perf when CDNA2 is available.
return false;
} else if (GGML_CUDA_CC_IS_CDNA1(cc) && (type == GGML_TYPE_F16 || type == GGML_TYPE_BF16)) {
return false;
} else if (src1_ncols > 16) { } else if (src1_ncols > 16) {
return false; return false;
} }
@ -160,11 +180,11 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
switch (type) { switch (type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
return ampere_mma_available(cc); return ampere_mma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_F16: case GGML_TYPE_F16:
return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc); return volta_mma_available(cc) || turing_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
case GGML_TYPE_BF16: case GGML_TYPE_BF16:
return ampere_mma_available(cc) || amd_wmma_available(cc); return ampere_mma_available(cc) || amd_wmma_available(cc) || amd_mfma_available(cc);
default: default:
return false; return false;
} }


@ -7,6 +7,31 @@
using namespace ggml_cuda_mma; using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32 #define MMF_ROWS_PER_BLOCK 32
#define MMF_ROWS_PER_BLOCK_CDNA 64
static __forceinline__ int64_t mmf_get_max_block_size(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 512;
} else {
return 256;
}
}
static __forceinline__ int mmf_get_padding(int cc) {
if (GGML_CUDA_CC_IS_CDNA(cc)) {
return 2;
} else {
return 4;
}
}
static constexpr __device__ int mmf_get_padding() {
#if defined(AMD_MFMA_AVAILABLE)
return 2;
#else
return 4;
#endif // defined(AMD_MFMA_AVAILABLE)
}
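Editor's note: with these helpers, the shared-memory row stride used below (warp_size + mmf_get_padding()) works out to 64 + 2 = 66 floats on CDNA and 32 + 4 = 36 elsewhere; the uneven stride presumably staggers rows across shared-memory banks (an assumption, the patch does not state the rationale).

// Illustrative only: the strides implied by mmf_get_padding() above.
constexpr int tile_k_padded_cdna  = 64 + 2; // wavefront size + padding on CDNA
constexpr int tile_k_padded_other = 32 + 4; // warp size + padding elsewhere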
struct mmf_ids_data { struct mmf_ids_data {
const int32_t * ids_src_compact = nullptr; const int32_t * ids_src_compact = nullptr;
@ -29,23 +54,25 @@ static __global__ void mul_mat_f(
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it. if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
constexpr bool is_tf32 = std::is_same_v<T, float>; typedef tile<16, 8, T, get_input_data_layout()> tile_A;
constexpr int tile_B_I = is_tf32 ? 8 : 16; typedef tile<16, 8, T, get_input_data_layout()> tile_B;
constexpr int tile_C_J = is_tf32 ? 8 : 16; typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); #elif defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, T, ab_layout> tile_A; if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<tile_B_I, 8, T, ab_layout> tile_B; typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else #else
#ifdef VOLTA_MMA_AVAILABLE #ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else { if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else #else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A; typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B; typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C; typedef tile<16, 8, float> tile_C;
@ -57,7 +84,7 @@ static __global__ void mul_mat_f(
} }
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4; constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -198,7 +225,7 @@ static __global__ void mul_mat_f(
} }
float * buf_iw = (float *) compute_base; float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4; constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) { if (nwarps > 1) {
__syncthreads(); __syncthreads();
@ -228,27 +255,34 @@ static __global__ void mul_mat_f(
return; return;
} }
float sum = 0.0f; float sum[rows_per_block/warp_size] = {0.0f};
static_assert(rows_per_block == warp_size, "need loop/check"); static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x; #pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i]; sum[i1] += buf_iw[j*kiw + i];
}
} }
if constexpr (!has_ids) { if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum; #pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} else { } else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1; const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0 && (col_base + j) < ncols_dst_total) { if (slot >= 0 && (col_base + j) < ncols_dst_total) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum; #pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} }
} }
} }
#ifdef VOLTA_MMA_AVAILABLE
} }
#endif //VOLTA_MMA_AVAILABLE
#else #else
GGML_UNUSED_VARS(x, y, ids, dst, GGML_UNUSED_VARS(x, y, ids, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
@ -256,7 +290,7 @@ static __global__ void mul_mat_f(
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
} }
//This kernel is for larger batch sizes of mul_mat_id //This kernel is for larger batch sizes of mul_mat_id
@ -271,23 +305,25 @@ static __global__ void mul_mat_f_ids(
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const uint3 sis1_fd, const uint3 nch_fd) { const uint3 sis1_fd, const uint3 nch_fd) {
// TODO: handle this in a consistent and simpler way after AMD MFMA support has been added // TODO: handle this in a consistent and simpler way after AMD MFMA support has been added
#if (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #if defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
#if defined(AMD_WMMA_AVAILABLE) #if defined(AMD_WMMA_AVAILABLE)
// Special case for tf32, just dummy mma layout as wmma doesn't support it. if constexpr (!(std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
constexpr bool is_tf32 = std::is_same_v<T, float>; typedef tile<16, 8, T, get_input_data_layout()> tile_A;
constexpr int tile_B_I = is_tf32 ? 8 : 16; typedef tile<16, 8, T, get_input_data_layout()> tile_B;
constexpr int tile_C_J = is_tf32 ? 8 : 16; typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
constexpr data_layout ab_layout = is_tf32 ? DATA_LAYOUT_I_MAJOR : get_input_data_layout(); #elif defined(AMD_MFMA_AVAILABLE)
typedef tile<16, 8, T, ab_layout> tile_A; if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK_CDNA) {NO_DEVICE_CODE;} else {
typedef tile<tile_B_I, 8, T, ab_layout> tile_B; typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile<16, tile_C_J, float, DATA_LAYOUT_J_MAJOR> tile_C; typedef tile<16, 8, T, DATA_LAYOUT_I_MAJOR> tile_B;
typedef tile<16, 16, float, DATA_LAYOUT_J_MAJOR> tile_C;
#else #else
#ifdef VOLTA_MMA_AVAILABLE #ifdef VOLTA_MMA_AVAILABLE
if constexpr (!std::is_same_v<T, half2>) {NO_DEVICE_CODE;} else { if constexpr (!std::is_same_v<T, half2> || rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A; typedef tile<32, 4, T, DATA_LAYOUT_I_MAJOR> tile_A;
typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B; typedef tile< 8, 4, T, DATA_LAYOUT_I_MAJOR_MIRRORED> tile_B;
typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C; typedef tile<32, 8, float, DATA_LAYOUT_I_MAJOR> tile_C;
#else #else
if constexpr (rows_per_block != MMF_ROWS_PER_BLOCK) {NO_DEVICE_CODE;} else {
typedef tile<16, 8, T> tile_A; typedef tile<16, 8, T> tile_A;
typedef tile<8, 8, T> tile_B; typedef tile<8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C; typedef tile<16, 8, float> tile_C;
@ -300,7 +336,7 @@ static __global__ void mul_mat_f_ids(
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4; constexpr int tile_k_padded = warp_size + mmf_get_padding();
constexpr int ntA = rows_per_block / tile_A::I; constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I; constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
@ -467,7 +503,7 @@ static __global__ void mul_mat_f_ids(
} }
float * buf_iw = (float *) compute_base; float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4; constexpr int kiw = nwarps*rows_per_block + mmf_get_padding();
if (nwarps > 1) { if (nwarps > 1) {
__syncthreads(); __syncthreads();
@ -497,13 +533,16 @@ static __global__ void mul_mat_f_ids(
return; return;
} }
float sum = 0.0f; float sum[rows_per_block/warp_size] = {0.0f};
static_assert(rows_per_block == warp_size, "need loop/check"); static_assert((rows_per_block % warp_size) == 0, "rows_per_block must be a multiple of warp_size.");
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) { for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x; #pragma unroll
for (int i1 = 0; i1 < sizeof(sum)/sizeof(sum[0]); ++i1) {
const int i = i0 + i1*warp_size + threadIdx.x;
sum += buf_iw[j*kiw + i]; sum[i1] += buf_iw[j * kiw + i];
}
} }
const int global_j = col_base + j; const int global_j = col_base + j;
@ -513,23 +552,24 @@ static __global__ void mul_mat_f_ids(
const int token = (int) qrm.x; const int token = (int) qrm.x;
if (token < ncols_dst_total) { if (token < ncols_dst_total) {
const int slot = (int) qrm.y; const int slot = (int) qrm.y;
dst[slot*stride_channel_dst + token*stride_col_dst + row0 + threadIdx.x] = sum; #pragma unroll
for (int i0 = 0; i0 < sizeof(sum)/sizeof(sum[0]); ++i0) {
dst[slot * stride_channel_dst + token * stride_col_dst + row0 + i0*warp_size + threadIdx.x] = sum[i0];
}
} }
} }
} }
#ifdef VOLTA_MMA_AVAILABLE
} }
#endif // VOLTA_MMA_AVAILABLE
#else #else
GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst, GGML_UNUSED_VARS(x, y, ids_src_compact, ids_dst_compact, expert_bounds, dst,
ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst, ncols, ncols_dst_total, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, sis1_fd, nch_fd);
NO_DEVICE_CODE; NO_DEVICE_CODE;
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)) || defined(AMD_WMMA_AVAILABLE) #endif // defined(VOLTA_MMA_AVAILABLE) || defined(TURING_MMA_AVAILABLE) || defined(AMD_WMMA_AVAILABLE) || defined(AMD_MFMA_AVAILABLE)
} }
template<typename T, int cols_per_block, int nwarps> template<typename T, int rows_per_block, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids( static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst, const int64_t ncols_x, const int64_t ncols_dst, const int64_t nchannels_dst,
@ -553,7 +593,7 @@ static inline void mul_mat_f_switch_ids(
const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1); const uint3 sis1_fd = ids_data->sis1 > 0 ? init_fastdiv_values((uint32_t) ids_data->sis1) : make_uint3(0, 0, 1);
const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst); const uint3 nch_fd = init_fastdiv_values((uint32_t) nchannels_dst);
mul_mat_f_ids<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> mul_mat_f_ids<T, rows_per_block, cols_per_block, nwarps><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst, (x, y, ids_data->ids_src_compact, ids_data->ids_dst_compact, ids_data->expert_bounds_dev, dst,
ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
@ -564,19 +604,19 @@ static inline void mul_mat_f_switch_ids(
dim3 block_nums_ids = block_nums; dim3 block_nums_ids = block_nums;
block_nums_ids.y *= col_tiles; block_nums_ids.y *= col_tiles;
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>> mul_mat_f<T, rows_per_block, cols_per_block, nwarps, true><<<block_nums_ids, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else { } else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>> mul_mat_f<T, rows_per_block, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, dst, ncols_x, cols_per_block, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} }
} }
template <typename T, int cols_per_block> template <typename T, int rows_per_block, int cols_per_block>
void mul_mat_f_cuda( void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -605,7 +645,7 @@ void mul_mat_f_cuda(
int64_t nwarps_best = 1; int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2); int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256; int64_t max_block_size = mmf_get_max_block_size(cc);
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) { for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2); const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) { if (niter < niter_best) {
@ -614,10 +654,9 @@ void mul_mat_f_cuda(
} }
} }
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK; const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + mmf_get_padding(cc)) * 4;
const int nbytes_shared_iter = nwarps_best * (volta_mma_available(cc) ? tile_A_32::I : tile_A_16::I) * (warp_size + 4) * 4; const int nbytes_cols_per_block_pad = (amd_wmma_available(cc) || amd_mfma_available(cc)) ? tile_B_16::I : tile_B_8::I;
const int nbytes_cols_per_block_pad = amd_wmma_available(cc) ? tile_B_16::I : tile_B_8::I; const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + mmf_get_padding(cc)) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, nbytes_cols_per_block_pad) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine); const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0; const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap; const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
@ -628,56 +667,56 @@ void mul_mat_f_cuda(
switch (nwarps_best) { switch (nwarps_best) {
case 1: { case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 1>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 2>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 3>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 4>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 5>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 6>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 7>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
ids_data); ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>( mul_mat_f_switch_ids<T, rows_per_block, cols_per_block, 8>(
x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst, x, y, ids, dst, ncols_x, ncols_dst, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream, sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream,
@ -691,7 +730,7 @@ void mul_mat_f_cuda(
GGML_UNUSED_VARS(nchannels_y); GGML_UNUSED_VARS(nchannels_y);
} }
template <typename T> template <typename T, int rows_per_block>
static void mul_mat_f_switch_cols_per_block( static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst, const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
@ -708,82 +747,82 @@ static void mul_mat_f_switch_cols_per_block(
switch (ncols_case) { switch (ncols_case) {
case 1: { case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 1>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 2: { case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 2>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 3: { case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 3>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 4: { case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 4>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 5: { case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 5>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 6: { case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 6>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 7: { case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 7>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 8: { case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 8>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 9: { case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 9>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 10: { case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 10>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 11: { case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 11>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 12: { case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 12>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 13: { case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 13>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 14: { case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 14>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 15: { case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 15>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
case 16: { case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst, mul_mat_f_cuda<T, rows_per_block, 16>(x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break; } break;
@ -793,8 +832,36 @@ static void mul_mat_f_switch_cols_per_block(
} }
} }
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \ template <typename T>
template void mul_mat_f_cuda<T, ncols_dst>( \ static void mul_mat_f_switch_rows_per_block(
const int rows_per_block, const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream, const mmf_ids_data * ids_data) {
switch (rows_per_block) {
case MMF_ROWS_PER_BLOCK: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
case MMF_ROWS_PER_BLOCK_CDNA: {
mul_mat_f_switch_cols_per_block<T, MMF_ROWS_PER_BLOCK_CDNA>(
x, y, ids, dst, ncols_x, nrows_x, ncols_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream, ids_data);
} break;
default:
GGML_ABORT("unsupported rows_per_block: %i", rows_per_block);
}
}
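
The new mul_mat_f_switch_rows_per_block above turns the runtime rows_per_block argument into a compile-time template parameter by switching over the two supported values. A minimal standalone sketch of that dispatch pattern, assuming hypothetical names (kernel_impl, ROWS_GENERIC, ROWS_CDNA) and made-up values rather than the real ggml symbols:

    #include <cstdio>
    #include <cstdlib>

    constexpr int ROWS_GENERIC = 1; // stand-in for MMF_ROWS_PER_BLOCK (actual value not shown in this diff)
    constexpr int ROWS_CDNA    = 4; // stand-in for MMF_ROWS_PER_BLOCK_CDNA (actual value not shown in this diff)

    template <int rows_per_block>
    static void kernel_impl(int n) {
        // rows_per_block is a compile-time constant here, so loops that depend
        // on it can be fully unrolled by the compiler.
        std::printf("rows_per_block=%d, n=%d\n", rows_per_block, n);
    }

    static void dispatch(int rows_per_block, int n) {
        switch (rows_per_block) {
            case ROWS_GENERIC: kernel_impl<ROWS_GENERIC>(n); break;
            case ROWS_CDNA:    kernel_impl<ROWS_CDNA>(n);    break;
            default:           std::abort(); // unsupported rows_per_block, mirrors GGML_ABORT above
        }
    }

    int main() {
        dispatch(ROWS_GENERIC, 16);
        dispatch(ROWS_CDNA, 16);
        return 0;
    }
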
#define DECL_MMF_CASE_HELPER(T, nrows_dst, ncols_dst) \
template void mul_mat_f_cuda<T, nrows_dst, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \ const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \ const int64_t ncols_x, const int64_t nrows_x, int64_t ncols_dst_total, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \ const int64_t stride_col_id, const int64_t stride_row_id, \
@ -803,16 +870,22 @@ static void mul_mat_f_switch_cols_per_block(
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \ const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream, const mmf_ids_data * ids_data); cudaStream_t stream, const mmf_ids_data * ids_data);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) #if !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \ #define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \ extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \ extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \ #define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \ DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \ DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst) DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK, ncols_dst) \
DECL_MMF_CASE_HELPER(float, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, MMF_ROWS_PER_BLOCK_CDNA, ncols_dst)
DECL_MMF_CASE_EXTERN(1); DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2); DECL_MMF_CASE_EXTERN(2);
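
The DECL_MMF_CASE_EXTERN / DECL_MMF_CASE pair above is the usual extern-template split, now covering both row counts: translation units that merely call the kernel see an explicit instantiation declaration, and exactly one per-case source file emits the definition. A single-file sketch of the mechanism, using a hypothetical mmf_stub rather than the real mul_mat_f_cuda signature:

    #include <cstdio>

    // Hypothetical stand-in for the kernel launcher; the real mul_mat_f_cuda
    // has a much longer parameter list.
    template <typename T, int nrows_dst, int ncols_dst>
    void mmf_stub(const T * x, float * dst) {
        dst[0] = static_cast<float>(x[0]) * nrows_dst * ncols_dst;
    }

    // DECL_MMF_CASE_EXTERN expands to declarations like this one: callers
    // promise that the instantiation is emitted in some other translation unit.
    extern template void mmf_stub<float, 32, 1>(const float *, float *);

    // DECL_MMF_CASE, used in exactly one per-case source file, emits the definition.
    template void mmf_stub<float, 32, 1>(const float *, float *);

    int main() {
        float x = 2.0f, d = 0.0f;
        mmf_stub<float, 32, 1>(&x, &d);
        std::printf("%.1f\n", d); // 2 * 32 * 1 = 64.0
        return 0;
    }
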

View File

@ -3697,13 +3697,20 @@ static __global__ void mul_mat_q(
tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop); tile_x_max_i, tile_y_max_j, kb0_start, kb0_stop);
} }
template <ggml_type type, int mmq_x, bool need_check> template <ggml_type type, int mmq_x, bool need_check>
static __global__ void mul_mat_q_stream_k_fixup( static __global__ void mul_mat_q_stream_k_fixup(const int32_t * ids_dst,
const int32_t * ids_dst, const int32_t * expert_bounds, float * __restrict__ dst, const float * __restrict__ tmp_last_tile, const int32_t * expert_bounds,
const int ncols_x, const int nrows_x, const int ncols_dst, const int stride_col_dst, float * __restrict__ dst,
const int nchannels_y, const int stride_channel_dst, const int nsamples_y, const int stride_sample_dst, const float * __restrict__ tmp_last_tile,
const int ncols_max) { const int ncols_x,
const int nrows_x,
const int ncols_dst,
const size_t stride_col_dst,
const int nchannels_y,
const size_t stride_channel_dst,
const int nsamples_y,
const size_t stride_sample_dst,
const int ncols_max) {
constexpr int mmq_y = get_mmq_y_device(); constexpr int mmq_y = get_mmq_y_device();
constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int ITER_K = get_iter_k(type); constexpr int ITER_K = get_iter_k(type);
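
The stride parameters of mul_mat_q_stream_k_fixup are widened from int to size_t in the hunk above. A plausible motivation (an assumption; the diff does not state it) is that channel and sample offsets computed from 32-bit strides can overflow for very large tensors. A small illustration of that failure mode:

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t nchannels  = 64;
        const uint32_t stride_u32 = 100000000u;         // elements per channel, 32-bit
        const size_t   stride_sz  = 100000000u;         // same stride, 64-bit
        const uint32_t off_bad  = nchannels * stride_u32;          // wraps modulo 2^32
        const size_t   off_good = (size_t) nchannels * stride_sz;  // 6'400'000'000, exact
        std::printf("32-bit offset: %u, 64-bit offset: %zu\n", off_bad, off_good);
        return 0;
    }
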

View File

@ -4,26 +4,48 @@
#include "mmvf.cuh" #include "mmvf.cuh"
#include "convert.cuh" #include "convert.cuh"
template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false> template <typename T, typename type_acc, int ncols_dst, int block_size, bool has_fusion = false, bool is_multi_token_id = false>
static __global__ void mul_mat_vec_f( static __global__ void mul_mat_vec_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const int ncols2, const int nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst, const int ncols2, const uint3 nchannels_y, const int stride_row, const int stride_col_y2, const int stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const int ids_stride) {
const int row = blockIdx.x; const int row = blockIdx.x;
// for MUL_MAT_ID: gridDim.y = n_expert_used (blockIdx.y = expert slot), gridDim.z = ncols_dst (blockIdx.z = token)
const int channel_dst = blockIdx.y; const int channel_dst = blockIdx.y;
const int channel_x = ids ? ids[channel_dst] : fastdiv((uint32_t) channel_dst, channel_ratio); const int tid = threadIdx.x;
const int channel_y = ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z; int token_idx;
int channel_x;
int channel_y;
int sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path; computing these indices in the normal path causes a perf regression for the n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
token_idx = ids ? blockIdx.z : 0;
channel_x = ids ? ids[blockIdx.y + token_idx * ids_stride] : fastdiv((uint32_t) channel_dst, channel_ratio);
channel_y = ids ? fastmodulo(blockIdx.y, nchannels_y) : channel_dst;
sample_dst = ids ? 0 : blockIdx.z;
}
const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio); const int sample_x = fastdiv((uint32_t) sample_dst, sample_ratio);
const int sample_y = sample_dst; const int sample_y = sample_dst;
const int tid = threadIdx.x;
constexpr int warp_size = ggml_cuda_get_physical_warp_size(); constexpr int warp_size = ggml_cuda_get_physical_warp_size();
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y; y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst; dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y2*2;
dst += token_idx*stride_col_dst;
}
bool use_gate = false; bool use_gate = false;
bool use_bias = false; bool use_bias = false;
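
In the is_multi_token_id path added above, the launch grid maps blockIdx.y to the expert slot (channel_dst) and blockIdx.z to the token, and the expert whose weights are read is looked up as ids[channel_dst + token_idx*ids_stride]. A host-side sketch of that index math, with all sizes and ids contents made up for illustration:

    #include <cstdio>
    #include <vector>

    int main() {
        const int n_expert_used = 2;   // gridDim.y in the kernel -> channel_dst
        const int n_tokens      = 3;   // gridDim.z in the kernel -> token_idx
        const int ids_stride    = 4;   // assumed element stride between ids rows (rows may be padded)
        // ids[token*ids_stride + slot]: expert chosen for each (token, slot)
        const std::vector<int> ids = { 5, 1, 0, 0,
                                       2, 7, 0, 0,
                                       5, 3, 0, 0 };
        for (int token_idx = 0; token_idx < n_tokens; ++token_idx) {
            for (int channel_dst = 0; channel_dst < n_expert_used; ++channel_dst) {
                const int channel_x = ids[channel_dst + token_idx * ids_stride]; // expert whose weights are read
                std::printf("token %d, slot %d -> expert %d\n", token_idx, channel_dst, channel_x);
            }
        }
        return 0;
    }
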
@ -56,8 +78,10 @@ static __global__ void mul_mat_vec_f(
if (use_gate) { if (use_gate) {
gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row; gate_x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row*stride_row;
} }
const int channel_bias = ids ? channel_x : channel_dst;
if constexpr (has_fusion) { if constexpr (has_fusion) {
const int channel_bias = ids ? channel_x : channel_dst;
if (use_bias) { if (use_bias) {
x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst; x_bias += int64_t(sample_dst)*stride_sample_dst + channel_bias*stride_channel_dst;
} }
@ -349,36 +373,36 @@ static __global__ void mul_mat_vec_f(
} }
} }
template<typename T, typename type_acc, int ncols_dst, int block_size> template<typename T, typename type_acc, int ncols_dst, int block_size, bool is_multi_token_id = false>
static void mul_mat_vec_f_switch_fusion( static void mul_mat_vec_f_switch_fusion(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols, const uint3 nchannels_y,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint3 channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const uint3 sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const cudaStream_t stream) { const dim3 & block_dims, const dim3 & block_nums, const int nbytes_shared, const int ids_stride, const cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (ncols_dst == 1) { if constexpr (ncols_dst == 1) {
if (has_fusion) { if (has_fusion) {
mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, block_size, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return; return;
} }
} }
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_f<T, type_acc, ncols_dst, block_size><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_f<T, type_acc, ncols_dst, block_size, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
} }
template <typename T, typename type_acc, int ncols_dst> template <typename T, typename type_acc, int ncols_dst, bool is_multi_token_id = false>
void launch_mul_mat_vec_f_cuda( void launch_mul_mat_vec_f_cuda(
const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const T * x, const float * y, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const int64_t ncols, const int64_t nrows, const int64_t ncols, const int64_t nrows,
@ -386,12 +410,13 @@ void launch_mul_mat_vec_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { const int64_t nsamples_or_ntokens, const int64_t ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols % 2 == 0); GGML_ASSERT(ncols % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0); GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0); GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0); GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0); GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x); const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x); const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
@ -415,56 +440,56 @@ void launch_mul_mat_vec_f_cuda(
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0); const int nbytes_shared = warp_size*sizeof(float) + (has_fusion ? warp_size*sizeof(float) : 0);
const dim3 block_nums(nrows, nchannels_dst, nsamples_dst); const dim3 block_nums(nrows, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(block_size_best, 1, 1); const dim3 block_dims(block_size_best, 1, 1);
switch (block_size_best) { switch (block_size_best) {
case 32: { case 32: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 32, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 64: { case 64: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 64, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 96: { case 96: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 96, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 128: { case 128: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 128, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 160: { case 160: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 160, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 192: { case 192: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 192, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 224: { case 224: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 224, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
case 256: { case 256: {
mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256> mul_mat_vec_f_switch_fusion<T, type_acc, ncols_dst, 256, is_multi_token_id>
(x, y, ids, fusion, dst, ncols/2, nchannels_y, stride_row, stride_col_y/2, stride_col_dst, (x, y, ids, fusion, dst, ncols/2, nchannels_y_fd, stride_row, stride_col_y/2, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, stream); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, block_dims, block_nums, nbytes_shared, ids_stride, stream);
} break; } break;
default: { default: {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -480,55 +505,88 @@ static void mul_mat_vec_f_cuda_switch_ncols_dst(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) { const int64_t ids_stride, cudaStream_t stream) {
const bool has_ids = ids != nullptr;
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only; the single-token MUL_MAT_ID case takes the path below
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst, true>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
ncols_dst, ids_stride, stream);
return;
}
if (has_ids) {
// Single-token MUL_MAT_ID path
constexpr int c_ncols_dst = 1;
launch_mul_mat_vec_f_cuda<T, type_acc, c_ncols_dst>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
ncols_dst, ids_stride, stream);
return;
}
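
The host-side selection added above works as follows: with an ids tensor and more than one token, the kernel is launched with ncols_dst fixed to 1 and the tokens spread over gridDim.z; with a single token it takes the original MUL_MAT_ID path; without ids it falls through to the compile-time ncols_dst switch below. A compact sketch of that decision, with hypothetical names:

    #include <cstdio>

    enum class Path { MultiTokenId, SingleTokenId, Regular };

    static Path pick_path(bool has_ids, long long ncols_dst) {
        if (has_ids && ncols_dst > 1) { return Path::MultiTokenId; } // tokens mapped onto gridDim.z
        if (has_ids)                  { return Path::SingleTokenId; }
        return Path::Regular;                                        // compile-time switch over ncols_dst 1..8
    }

    int main() {
        std::printf("%d %d %d\n",
                    (int) pick_path(true, 4),   // 0: MultiTokenId
                    (int) pick_path(true, 1),   // 1: SingleTokenId
                    (int) pick_path(false, 4)); // 2: Regular
        return 0;
    }
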
switch (ncols_dst) { switch (ncols_dst) {
case 1: case 1:
launch_mul_mat_vec_f_cuda<T, type_acc, 1> launch_mul_mat_vec_f_cuda<T, type_acc, 1>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 2: case 2:
launch_mul_mat_vec_f_cuda<T, type_acc, 2> launch_mul_mat_vec_f_cuda<T, type_acc, 2>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 3: case 3:
launch_mul_mat_vec_f_cuda<T, type_acc, 3> launch_mul_mat_vec_f_cuda<T, type_acc, 3>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 4: case 4:
launch_mul_mat_vec_f_cuda<T, type_acc, 4> launch_mul_mat_vec_f_cuda<T, type_acc, 4>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 5: case 5:
launch_mul_mat_vec_f_cuda<T, type_acc, 5> launch_mul_mat_vec_f_cuda<T, type_acc, 5>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 6: case 6:
launch_mul_mat_vec_f_cuda<T, type_acc, 6> launch_mul_mat_vec_f_cuda<T, type_acc, 6>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 7: case 7:
launch_mul_mat_vec_f_cuda<T, type_acc, 7> launch_mul_mat_vec_f_cuda<T, type_acc, 7>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
case 8: case 8:
launch_mul_mat_vec_f_cuda<T, type_acc, 8> launch_mul_mat_vec_f_cuda<T, type_acc, 8>
(x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst,
nsamples_dst, ids_stride, stream);
break; break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -544,21 +602,21 @@ static void mul_mat_vec_f_cuda(
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
enum ggml_prec prec, cudaStream_t stream) { const int64_t ids_stride, enum ggml_prec prec, cudaStream_t stream) {
if constexpr(std::is_same_v<T, half>) { if constexpr(std::is_same_v<T, half>) {
if (prec == GGML_PREC_DEFAULT) { if (prec == GGML_PREC_DEFAULT) {
mul_mat_vec_f_cuda_switch_ncols_dst<T, half> mul_mat_vec_f_cuda_switch_ncols_dst<T, half>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
return; return;
} }
} }
mul_mat_vec_f_cuda_switch_ncols_dst<T, float> mul_mat_vec_f_cuda_switch_ncols_dst<T, float>
(x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst, (x, y, ids, fusion, dst, ncols, nrows, ncols_dst, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y,
stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); stride_channel_dst, nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
} }
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
@ -573,7 +631,7 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
const size_t ts_src1 = ggml_type_size(src1->type); const size_t ts_src1 = ggml_type_size(src1->type);
const size_t ts_dst = ggml_type_size(dst->type); const size_t ts_dst = ggml_type_size(dst->type);
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. GGML_ASSERT(!ids || ne12 <= MMVF_MAX_BATCH_SIZE);
GGML_ASSERT(ne13 == ne3); GGML_ASSERT(ne13 == ne3);
GGML_ASSERT( nb00 == ts_src0); GGML_ASSERT( nb00 == ts_src0);
@ -626,29 +684,31 @@ void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor
const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2; const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12; const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1); const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0->data; const half * src0_d = (const half *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0->data;
mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, s11, s1, mul_mat_vec_f_cuda(src0_d, src1_d, ids_d, fusion_local, dst_d, ne00, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, prec, ctx.stream()); ne03, ne3, s03, s13, s3, ids_stride, prec, ctx.stream());
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
@ -695,19 +755,19 @@ void ggml_cuda_op_mul_mat_vec_f(
const float * src0_d = (const float *) src0_dd_i; const float * src0_d = (const float *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half * src0_d = (const half *) src0_dd_i; const half * src0_d = (const half *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i; const nv_bfloat16 * src0_d = (const nv_bfloat16 *) src0_dd_i;
mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst, mul_mat_vec_f_cuda(src0_d, src1_ddf_i, nullptr, empty, dst_dd_i, ne00, row_diff, src1_ncols, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, prec, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, 0, prec, stream);
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));

View File

@ -1,5 +1,7 @@
#include "common.cuh" #include "common.cuh"
#define MMVF_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVF kernels.
void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, void ggml_cuda_mul_mat_vec_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst,
const ggml_cuda_mm_fusion_args_host * fusion = nullptr); const ggml_cuda_mm_fusion_args_host * fusion = nullptr);
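
MMVF_MAX_BATCH_SIZE caps how many MUL_MAT_ID tokens the vector kernels handle per launch, matching the relaxed assertion ne12 <= MMVF_MAX_BATCH_SIZE earlier in this diff. A sketch of how such a cap is typically used to gate kernel selection; this is illustrative only and not llama.cpp's actual selection logic:

    #include <cstdio>

    constexpr long long kMaxVecBatch = 8; // mirrors MMVF_MAX_BATCH_SIZE

    static const char * pick_kernel(bool has_ids, long long ne12 /* tokens routed per expert batch */) {
        if (!has_ids)             { return "mul_mat_vec_f (plain GEMV)"; }
        if (ne12 <= kMaxVecBatch) { return "mul_mat_vec_f (MUL_MAT_ID, tokens on gridDim.z)"; }
        return "fall back to a tiled matrix-multiplication path";
    }

    int main() {
        std::printf("%s\n", pick_kernel(true, 4));
        std::printf("%s\n", pick_kernel(true, 64));
        std::printf("%s\n", pick_kernel(false, 1));
        return 0;
    }
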

View File

@ -137,15 +137,15 @@ static constexpr __host__ __device__ int calc_rows_per_block(int ncols_dst, int
return 1; return 1;
} }
// tell the compiler to use as many registers as it wants, see nwarps definition below template <ggml_type type, int ncols_dst, bool has_fusion, bool is_multi_token_id = false>
template <ggml_type type, int ncols_dst, bool has_fusion>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, const ggml_cuda_mm_fusion_args_device fusion, float * __restrict__ dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) { const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const uint32_t ids_stride) {
constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi; constexpr int qi = ggml_cuda_type_traits<type>::qi;
@ -162,11 +162,25 @@ static __global__ void mul_mat_vec_q(
const int blocks_per_row_x = ncols_x / qk; const int blocks_per_row_x = ncols_x / qk;
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const uint32_t channel_dst = blockIdx.y; const uint32_t channel_dst = blockIdx.y;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst; uint32_t token_idx = 0;
const uint32_t sample_dst = blockIdx.z; uint32_t channel_x;
uint32_t channel_y;
uint32_t sample_dst;
if constexpr (is_multi_token_id) {
// Multi-token MUL_MAT_ID path; computing these indices in the normal path causes a perf regression for the n_tokens=1 case
token_idx = blockIdx.z;
channel_x = ids[channel_dst + token_idx * ids_stride];
channel_y = fastmodulo(channel_dst, nchannels_y);
sample_dst = 0;
} else {
channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
sample_dst = blockIdx.z;
}
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio); const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst; const uint32_t sample_y = sample_dst;
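
fastdiv and fastmodulo above divide by a divisor that is fixed per launch (nchannels_y, channel_ratio, sample_ratio) using constants precomputed on the host by init_fastdiv_values, which avoids hardware integer division in the kernel. ggml packs those constants into a uint3; the exact encoding differs from the sketch below, which only illustrates the precomputed-reciprocal idea and requires a compiler with __uint128_t (GCC/Clang/NVCC host):

    #include <cstdint>
    #include <cstdio>

    struct FastDiv {
        uint64_t mult; // ceil(2^64 / d)
        uint32_t d;
    };

    static FastDiv make_fastdiv(uint32_t d) {
        // valid for 2 <= d < 2^32; d is known once per launch in the real code
        const __uint128_t one = 1;
        return { (uint64_t) (((one << 64) + d - 1) / d), d };
    }

    static uint32_t fast_div(uint32_t n, FastDiv f) {
        return (uint32_t) (((__uint128_t) n * f.mult) >> 64); // == n / f.d for all 32-bit n
    }

    static uint32_t fast_mod(uint32_t n, FastDiv f) {
        return n - fast_div(n, f) * f.d;                      // == n % f.d
    }

    int main() {
        const FastDiv f = make_fastdiv(6);
        for (uint32_t n = 0; n < 20; ++n) {
            if (fast_div(n, f) != n / 6 || fast_mod(n, f) != n % 6) {
                std::printf("mismatch at %u\n", n);
                return 1;
            }
        }
        std::printf("fastdiv/fastmod agree with / and %% for d = 6\n");
        return 0;
    }
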
@ -188,11 +202,11 @@ static __global__ void mul_mat_vec_q(
active_glu = fusion.glu_op; active_glu = fusion.glu_op;
} }
const uint32_t channel_bias = ids ? channel_x : channel_dst;
float x_biases[ncols_dst] = { 0.0f }; float x_biases[ncols_dst] = { 0.0f };
float gate_biases[ncols_dst] = { 0.0f }; float gate_biases[ncols_dst] = { 0.0f };
if constexpr (has_fusion) { if constexpr (has_fusion) {
const uint32_t channel_bias = ids ? channel_x : channel_dst;
if (use_bias) { if (use_bias) {
x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0; x_bias = x_bias + sample_dst*stride_sample_dst + channel_bias*stride_channel_dst + row0;
// 1. Hide latency by prefetching bias and gate here // 1. Hide latency by prefetching bias and gate here
@ -222,6 +236,9 @@ static __global__ void mul_mat_vec_q(
float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}}; float tmp_gate[ncols_dst][rows_per_cuda_block] = {{0.0f}};
const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y; const block_q8_1 * y = ((const block_q8_1 *) vy) + sample_y*stride_sample_y + channel_y*stride_channel_y;
if constexpr (is_multi_token_id) {
y += token_idx*stride_col_y;
}
const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x; const int kbx_offset = sample_x*stride_sample_x + channel_x*stride_channel_x + row0*stride_row_x;
for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) { for (int kbx = tid / (qi/vdr); kbx < blocks_per_row_x; kbx += blocks_per_iter) {
@ -275,6 +292,10 @@ static __global__ void mul_mat_vec_q(
dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0; dst += sample_dst*stride_sample_dst + channel_dst*stride_channel_dst + row0;
if constexpr (is_multi_token_id) {
dst += token_idx*stride_col_dst;
}
// sum up partial sums and write back result // sum up partial sums and write back result
#pragma unroll #pragma unroll
for (int j = 0; j < ncols_dst; ++j) { for (int j = 0; j < ncols_dst; ++j) {
@ -335,40 +356,41 @@ static __global__ void mul_mat_vec_q(
} }
static std::pair<dim3, dim3> calc_launch_params( static std::pair<dim3, dim3> calc_launch_params(
const int ncols_dst, const int nrows_x, const int nchannels_y, const int nsamples_y, const int ncols_dst, const int nrows_x, const int nchannels_dst, const int nsamples_or_ntokens,
const int warp_size, const mmvq_parameter_table_id table_id) { const int warp_size, const mmvq_parameter_table_id table_id) {
const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id); const int64_t nblocks = (nrows_x + calc_rows_per_block(ncols_dst, table_id) - 1) / calc_rows_per_block(ncols_dst, table_id);
const dim3 block_nums(nblocks, nchannels_y, nsamples_y); const dim3 block_nums(nblocks, nchannels_dst, nsamples_or_ntokens);
const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1); const dim3 block_dims(warp_size, calc_nwarps(ncols_dst, table_id), 1);
return {block_nums, block_dims}; return {block_nums, block_dims};
} }
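In the multi-token MUL_MAT_ID dispatch added below, the second grid dimension still walks the destination channels (the expert slots) while the third carries the token count instead of the sample count. A small sketch with illustrative numbers (the values here are hypothetical; rows_per_block stands in for whatever calc_rows_per_block() would return):

#include <cstdio>

int main() {
    const int nrows_x        = 4096;  // rows of the quantized expert matrix
    const int rows_per_block = 1;     // assumed result of calc_rows_per_block(1, table_id)
    const int n_expert_used  = 8;     // nchannels_dst -> blockIdx.y
    const int n_tokens       = 4;     // passed where nsamples normally goes -> blockIdx.z

    const int nblocks = (nrows_x + rows_per_block - 1) / rows_per_block;
    std::printf("grid = (%d, %d, %d)\n", nblocks, n_expert_used, n_tokens);
}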
template<ggml_type type, int c_ncols_dst> template<ggml_type type, int c_ncols_dst, bool is_multi_token_id = false>
static void mul_mat_vec_q_switch_fusion( static void mul_mat_vec_q_switch_fusion(
const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst, const void * vx, const void * vy, const int32_t * ids, const ggml_cuda_mm_fusion_args_device fusion, float * dst,
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio, const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst, const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared, cudaStream_t stream) { const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared,
const uint32_t ids_stride, cudaStream_t stream) {
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
if constexpr (c_ncols_dst == 1) { if constexpr (c_ncols_dst == 1) {
if (has_fusion) { if (has_fusion) {
mul_mat_vec_q<type, c_ncols_dst, true><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, true, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
return; return;
} }
} }
GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1"); GGML_ASSERT(!has_fusion && "fusion only supported for ncols_dst=1");
mul_mat_vec_q<type, c_ncols_dst, false><<<block_nums, block_dims, nbytes_shared, stream>>> mul_mat_vec_q<type, c_ncols_dst, false, is_multi_token_id><<<block_nums, block_dims, nbytes_shared, stream>>>
(vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride);
} }
template <ggml_type type> template <ggml_type type>
@ -379,7 +401,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) { const int ids_stride, cudaStream_t stream) {
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
@ -393,8 +415,19 @@ static void mul_mat_vec_q_switch_ncols_dst(
const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc); const mmvq_parameter_table_id table_id = get_device_table_id(ggml_cuda_info().devices[device].cc);
const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr; const bool has_fusion = fusion.gate != nullptr || fusion.x_bias != nullptr || fusion.gate_bias != nullptr;
const bool has_ids = ids != nullptr;
if (has_ids && ncols_dst > 1) {
// Multi-token MUL_MAT_ID path only - single-token goes through regular path below
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, ncols_dst, warp_size, table_id);
mul_mat_vec_q_switch_fusion<type, c_ncols_dst, true>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, ids_stride, stream);
return;
}
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) { switch (ncols_dst) {
case 1: { case 1: {
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
@ -402,7 +435,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 2: { case 2: {
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;
@ -410,7 +443,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 3: { case 3: {
constexpr int c_ncols_dst = 3; constexpr int c_ncols_dst = 3;
@ -418,7 +451,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 4: { case 4: {
constexpr int c_ncols_dst = 4; constexpr int c_ncols_dst = 4;
@ -426,7 +459,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 5: { case 5: {
constexpr int c_ncols_dst = 5; constexpr int c_ncols_dst = 5;
@ -434,7 +467,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 6: { case 6: {
constexpr int c_ncols_dst = 6; constexpr int c_ncols_dst = 6;
@ -442,7 +475,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 7: { case 7: {
constexpr int c_ncols_dst = 7; constexpr int c_ncols_dst = 7;
@ -450,7 +483,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
case 8: { case 8: {
constexpr int c_ncols_dst = 8; constexpr int c_ncols_dst = 8;
@ -458,7 +491,7 @@ static void mul_mat_vec_q_switch_ncols_dst(
mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst, mul_mat_vec_q_switch_fusion<type, c_ncols_dst>(vx, vy, ids, fusion, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst, sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst,
dims.first, dims.second, 0, stream); dims.first, dims.second, 0, ids_stride, stream);
} break; } break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -474,127 +507,127 @@ static void mul_mat_vec_q_switch_type(
const int nchannels_x, const int nchannels_y, const int nchannels_dst, const int nchannels_x, const int nchannels_y, const int nchannels_dst,
const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst, const int nsamples_x, const int nsamples_dst, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst,
cudaStream_t stream) { const int ids_stride, cudaStream_t stream) {
switch (type_x) { switch (type_x) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q4_1: case GGML_TYPE_Q4_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_1>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q8_0>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_MXFP4: case GGML_TYPE_MXFP4:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_MXFP4>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q2_K: case GGML_TYPE_Q2_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q2_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q3_K: case GGML_TYPE_Q3_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q3_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q4_K: case GGML_TYPE_Q4_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q4_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q5_K: case GGML_TYPE_Q5_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q5_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_Q6_K: case GGML_TYPE_Q6_K:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_Q6_K>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ2_XXS: case GGML_TYPE_IQ2_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ2_XS: case GGML_TYPE_IQ2_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ2_S: case GGML_TYPE_IQ2_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ2_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ3_XXS: case GGML_TYPE_IQ3_XXS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_XXS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ1_S: case GGML_TYPE_IQ1_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ1_M: case GGML_TYPE_IQ1_M:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ1_M>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ4_NL: case GGML_TYPE_IQ4_NL:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_NL>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ4_XS: case GGML_TYPE_IQ4_XS:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ4_XS>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
case GGML_TYPE_IQ3_S: case GGML_TYPE_IQ3_S:
mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S> mul_mat_vec_q_switch_ncols_dst<GGML_TYPE_IQ3_S>
(vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, fusion, dst, ncols_x, nrows_x, ncols_dst, stride_row_x, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream); nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, ids_stride, stream);
break; break;
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
@ -622,7 +655,7 @@ void ggml_cuda_mul_mat_vec_q(
GGML_ASSERT( nb0 == ts_dst); GGML_ASSERT( nb0 == ts_dst);
GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type)); GGML_ASSERT(!ids || ids->nb[0] == ggml_type_size(ids->type));
GGML_ASSERT(!ids || ne12 == 1); // Implementation is only correct for batch size 1. GGML_ASSERT(!ids || ne12 <= MMVQ_MAX_BATCH_SIZE);
const float * src1_d = (const float *) src1->data; const float * src1_d = (const float *) src1->data;
const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr; const int32_t * ids_d = ids ? (const int32_t *) ids->data : nullptr;
@ -693,11 +726,13 @@ void ggml_cuda_mul_mat_vec_q(
const int64_t stride_channel_dst = ids ? s1 : s2; const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12; const int64_t stride_channel_y = ids ? s11 : s12;
const int64_t ids_stride = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
mul_mat_vec_q_switch_type( mul_mat_vec_q_switch_type(
src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00, src0->data, src0->type, src1_q8_1.get(), ids_d, fusion_local, dst_d, ne00,
ne01, ncols_dst, s01, stride_col_y, stride_col_dst, ne01, ncols_dst, s01, stride_col_y, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst, ne02, nchannels_y, nchannels_dst, s02, stride_channel_y, stride_channel_dst,
ne03, ne3, s03, s13, s3, stream); ne03, ne3, s03, s13, s3, ids_stride, stream);
} }
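A sketch of the layout the new ids_stride argument describes (host side, hypothetical shapes): ids is an [n_expert_used, n_tokens] i32 matrix and ids_stride is its row stride in elements, so expert slot i of token t is read as ids[i + t*ids_stride].

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    const int n_expert_used = 4;
    const int n_tokens      = 3;
    const int ids_stride    = n_expert_used;  // nb[1] / ggml_type_size(GGML_TYPE_I32) for a contiguous tensor

    std::vector<int32_t> ids(n_expert_used * n_tokens);
    for (int t = 0; t < n_tokens; ++t) {
        for (int i = 0; i < n_expert_used; ++i) {
            ids[i + t * ids_stride] = (7 * t + i) % 16;  // made-up expert indices
        }
    }
    // in the kernel: channel_dst == i, blockIdx.z == t
    std::printf("token 2, slot 1 -> expert %d\n", ids[1 + 2 * ids_stride]);
}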
void ggml_cuda_op_mul_mat_vec_q( void ggml_cuda_op_mul_mat_vec_q(
@ -726,7 +761,7 @@ void ggml_cuda_op_mul_mat_vec_q(
ggml_cuda_mm_fusion_args_device fusion_local{}; ggml_cuda_mm_fusion_args_device fusion_local{};
mul_mat_vec_q_switch_type( mul_mat_vec_q_switch_type(
src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst, src0_dd_i, src0->type, src1_ddq_i, nullptr, fusion_local, dst_dd_i, ne00, row_diff, src1_ncols, stride_row_x, stride_col_y, nrows_dst,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, stream); 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, stream);
GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size); GGML_UNUSED_VARS(src1, dst, src1_ddf_i, src1_ncols, src1_padded_row_size);
} }

View File

@ -5,6 +5,13 @@
#include <cmath> #include <cmath>
#include <initializer_list> #include <initializer_list>
// Kernel config struct - passed by value to CUDA kernel
struct topk_moe_config {
bool use_sigmoid;
bool with_norm;
bool delayed_softmax;
};
// Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path. // Warp-local softmax used for both the pre-top-k logits and the post-top-k delayed path.
template <int experts_per_thread, bool use_limit> template <int experts_per_thread, bool use_limit>
__device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) { __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
@ -50,6 +57,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
} }
} }
template <int experts_per_thread, bool use_limit>
__device__ void sigmoid_warp_inplace(float (&vals)[experts_per_thread], const int limit, const int lane) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
const int idx = lane + i * WARP_SIZE;
const bool active = !use_limit || (idx < limit);
vals[i] = active ? 1.f / (1.f + expf(-vals[i])) : -INFINITY;
}
}
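Both warp helpers above assume the same register layout for the per-token logits. A sketch of that mapping (assuming, as the indexing suggests, that each thread keeps n_experts / WARP_SIZE slots, with a minimum of one):

#include <cstdio>

constexpr int WARP_SIZE = 32;

int main() {
    const int n_experts          = 128;  // matches the kernel template parameter
    const int experts_per_thread = n_experts > WARP_SIZE ? n_experts / WARP_SIZE : 1;

    const int sample_experts[] = {0, 31, 32, 127};
    for (int e : sample_experts) {
        // expert e lives in register slot e / WARP_SIZE of lane e % WARP_SIZE
        std::printf("expert %3d -> lane %2d, slot %d of %d\n",
                    e, e % WARP_SIZE, e / WARP_SIZE, experts_per_thread);
    }
}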
/* /*
This kernel does the following: This kernel does the following:
1. optionally softmax over the logits per token [n_experts, n_tokens] 1. optionally softmax over the logits per token [n_experts, n_tokens]
@ -59,13 +76,16 @@ __device__ void softmax_warp_inplace(float (&vals)[experts_per_thread], const in
It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models
*/ */
template <int n_experts, bool with_norm, bool delayed_softmax = false> template <int n_experts, bool has_bias>
__launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits,
float * weights, float * weights,
int32_t * ids, int32_t * ids,
const int n_rows, float * bias,
const int n_expert_used, const int n_rows,
const float clamp_val) { const int n_expert_used,
const float clamp_val,
const float scale_val,
const topk_moe_config config) {
const int row = blockIdx.x * blockDim.y + threadIdx.y; const int row = blockIdx.x * blockDim.y + threadIdx.y;
if (row >= n_rows) { if (row >= n_rows) {
return; return;
@ -79,14 +99,41 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float wt[experts_per_thread]; float wt[experts_per_thread];
// Initialize all slots to -INFINITY
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
wt[i] = -INFINITY;
}
#pragma unroll #pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) { for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x; const int expert = i + threadIdx.x;
wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY; wt[i / WARP_SIZE] = (n_experts % WARP_SIZE == 0 || expert < n_experts) ? logits[expert] : -INFINITY;
} }
if constexpr (!delayed_softmax) { if (!config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x); if (config.use_sigmoid) {
sigmoid_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
} else {
softmax_warp_inplace<experts_per_thread, false>(wt, n_experts, threadIdx.x);
}
}
// selection_wt is only needed when bias is present (selection uses wt + bias)
// when no bias, we use wt directly for both selection and weight values
float selection_wt[has_bias ? experts_per_thread : 1];
if constexpr (has_bias) {
#pragma unroll
for (int i = 0; i < experts_per_thread; i++) {
selection_wt[i] = -INFINITY;
}
#pragma unroll
for (int i = 0; i < n_experts; i += WARP_SIZE) {
const int expert = i + threadIdx.x;
selection_wt[i / WARP_SIZE] =
(n_experts % WARP_SIZE == 0 || expert < n_experts) ? wt[i / WARP_SIZE] + bias[expert] : -INFINITY;
}
} }
//at this point, each thread holds either a portion of the softmax distribution //at this point, each thread holds either a portion of the softmax distribution
@ -106,22 +153,56 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
float max_val = wt[0]; float max_val = wt[0];
int max_expert = threadIdx.x; int max_expert = threadIdx.x;
#pragma unroll if constexpr (has_bias) {
for (int i = 1; i < experts_per_thread; i++) { float max_val_s = selection_wt[0];
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll #pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) { for (int i = 1; i < experts_per_thread; i++) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE); const int expert = threadIdx.x + i * WARP_SIZE;
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE); if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && selection_wt[i] > max_val_s) {
if (val > max_val || (val == max_val && expert < max_expert)) { max_val = wt[i];
max_val = val; max_val_s = selection_wt[i];
max_expert = expert; max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const float val_s = __shfl_xor_sync(0xFFFFFFFF, max_val_s, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val_s > max_val_s || (val_s == max_val_s && expert < max_expert)) {
max_val = val;
max_val_s = val_s;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
selection_wt[max_expert / WARP_SIZE] = -INFINITY;
}
} else {
#pragma unroll
for (int i = 1; i < experts_per_thread; i++) {
const int expert = threadIdx.x + i * WARP_SIZE;
if ((n_experts % WARP_SIZE == 0 || expert < n_experts) && wt[i] > max_val) {
max_val = wt[i];
max_expert = expert;
}
}
#pragma unroll
for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
const float val = __shfl_xor_sync(0xFFFFFFFF, max_val, mask, WARP_SIZE);
const int expert = __shfl_xor_sync(0xFFFFFFFF, max_expert, mask, WARP_SIZE);
if (val > max_val || (val == max_val && expert < max_expert)) {
max_val = val;
max_expert = expert;
}
}
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
} }
} }
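A host simulation of the __shfl_xor_sync argmax butterfly used in both branches above (a sketch, not the author's code): after log2(WARP_SIZE) exchanges every lane holds the warp-wide maximum, with ties broken towards the smaller index.

#include <cstdio>

constexpr int WARP_SIZE = 32;

int main() {
    float val[WARP_SIZE];
    int   idx[WARP_SIZE];
    for (int l = 0; l < WARP_SIZE; ++l) { val[l] = float((l * 7) % 13); idx[l] = l; }

    for (int mask = WARP_SIZE / 2; mask > 0; mask /= 2) {
        float v2[WARP_SIZE]; int i2[WARP_SIZE];
        for (int l = 0; l < WARP_SIZE; ++l) { v2[l] = val[l ^ mask]; i2[l] = idx[l ^ mask]; }  // the "shuffle"
        for (int l = 0; l < WARP_SIZE; ++l) {
            if (v2[l] > val[l] || (v2[l] == val[l] && i2[l] < idx[l])) { val[l] = v2[l]; idx[l] = i2[l]; }
        }
    }
    std::printf("every lane now holds: max=%g at index %d\n", val[0], idx[0]);
}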
@ -130,16 +211,14 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
} }
if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) { if ((max_expert & (WARP_SIZE - 1)) == threadIdx.x) {
wt[max_expert / WARP_SIZE] = -INFINITY;
ids[k] = max_expert; ids[k] = max_expert;
if constexpr (with_norm) { if (config.with_norm) {
wt_sum += max_val; wt_sum += max_val;
} }
} }
} }
if constexpr (with_norm) { if (config.with_norm) {
wt_sum = warp_reduce_sum(wt_sum); wt_sum = warp_reduce_sum(wt_sum);
wt_sum = max(wt_sum, clamp_val); wt_sum = max(wt_sum, clamp_val);
const float inv_sum = 1.0f / wt_sum; const float inv_sum = 1.0f / wt_sum;
@ -149,7 +228,7 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
} }
} }
if constexpr (delayed_softmax) { if (config.delayed_softmax) {
softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x); softmax_warp_inplace<experts_per_thread, true>(output_weights, n_expert_used, threadIdx.x);
} }
@ -157,25 +236,25 @@ __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float *
for (int i = 0; i < experts_per_thread; i++) { for (int i = 0; i < experts_per_thread; i++) {
const int idx = i * WARP_SIZE + threadIdx.x; const int idx = i * WARP_SIZE + threadIdx.x;
if (idx < n_expert_used) { if (idx < n_expert_used) {
weights[idx] = output_weights[i]; weights[idx] = output_weights[i] * scale_val;
} }
} }
if (!with_norm) {
GGML_UNUSED(clamp_val);
}
} }
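A minimal host-side reference of the fused gating pipeline the kernel implements for one token (a sketch, not the author's code; the delayed-softmax path and exact tie-breaking are omitted): optional sigmoid or softmax, greedy top-k driven by wt plus an optional selection bias, optional normalization with a clamp on the sum, and a final scale.

#include <algorithm>
#include <cmath>
#include <vector>

struct gating_cfg { bool use_sigmoid; bool with_norm; float clamp_val; float scale_val; };

static void topk_moe_reference(const std::vector<float> & logits, const float * bias, int k,
                               const gating_cfg & cfg, std::vector<float> & weights, std::vector<int> & ids) {
    std::vector<float> wt(logits);
    if (cfg.use_sigmoid) {
        for (float & v : wt) { v = 1.0f / (1.0f + std::exp(-v)); }
    } else {
        const float mx = *std::max_element(wt.begin(), wt.end());
        float sum = 0.0f;
        for (float & v : wt) { v = std::exp(v - mx); sum += v; }
        for (float & v : wt) { v /= sum; }
    }
    // the bias only steers which experts are selected; the stored weight stays unbiased
    std::vector<float> sel(wt);
    if (bias) { for (int e = 0; e < (int) sel.size(); ++e) { sel[e] += bias[e]; } }

    weights.assign(k, 0.0f);
    ids.assign(k, -1);
    float wt_sum = 0.0f;
    for (int j = 0; j < k; ++j) {
        const int best = int(std::max_element(sel.begin(), sel.end()) - sel.begin());
        ids[j]     = best;
        weights[j] = wt[best];
        wt_sum    += wt[best];
        sel[best]  = -INFINITY;  // knock the winner out, as the kernel does in registers
    }
    if (cfg.with_norm) {
        const float inv = 1.0f / std::max(wt_sum, cfg.clamp_val);
        for (float & w : weights) { w *= inv; }
    }
    for (float & w : weights) { w *= cfg.scale_val; }
}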
template <bool with_norm, bool delayed_softmax = false> template<bool has_bias>
static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx, static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
const float * logits, const float * logits,
float * weights, float * weights,
int32_t * ids, int32_t * ids,
float * bias,
const int n_rows, const int n_rows,
const int n_expert, const int n_expert,
const int n_expert_used, const int n_expert_used,
const float clamp_val) { const float clamp_val,
static_assert(!(with_norm && delayed_softmax), "delayed softmax is not supported with weight normalization"); const float scale_val,
const topk_moe_config config) {
GGML_ASSERT(!(config.with_norm && config.delayed_softmax) &&
"delayed softmax is not supported with weight normalization");
const int rows_per_block = 4; const int rows_per_block = 4;
dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1); dim3 grid_dims((n_rows + rows_per_block - 1) / rows_per_block, 1, 1);
dim3 block_dims(WARP_SIZE, rows_per_block, 1); dim3 block_dims(WARP_SIZE, rows_per_block, 1);
@ -183,44 +262,48 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
switch (n_expert) { switch (n_expert) {
case 1: case 1:
topk_moe_cuda<1, with_norm, delayed_softmax> topk_moe_cuda<1, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 2: case 2:
topk_moe_cuda<2, with_norm, delayed_softmax> topk_moe_cuda<2, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 4: case 4:
topk_moe_cuda<4, with_norm, delayed_softmax> topk_moe_cuda<4, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 8: case 8:
topk_moe_cuda<8, with_norm, delayed_softmax> topk_moe_cuda<8, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 16: case 16:
topk_moe_cuda<16, with_norm, delayed_softmax> topk_moe_cuda<16, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 32: case 32:
topk_moe_cuda<32, with_norm, delayed_softmax> topk_moe_cuda<32, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 64: case 64:
topk_moe_cuda<64, with_norm, delayed_softmax> topk_moe_cuda<64, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 128: case 128:
topk_moe_cuda<128, with_norm, delayed_softmax> topk_moe_cuda<128, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 256: case 256:
topk_moe_cuda<256, with_norm, delayed_softmax> topk_moe_cuda<256, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break; break;
case 512: case 512:
topk_moe_cuda<512, with_norm, delayed_softmax> topk_moe_cuda<512, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
<<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, n_rows, n_expert_used, clamp_val); clamp_val, scale_val, config);
break;
case 576:
topk_moe_cuda<576, has_bias><<<grid_dims, block_dims, 0, stream>>>(logits, weights, ids, bias, n_rows, n_expert_used,
clamp_val, scale_val, config);
break; break;
default: default:
GGML_ASSERT(false && "fatal error"); GGML_ASSERT(false && "fatal error");
@ -228,13 +311,14 @@ static void launch_topk_moe_cuda(ggml_backend_cuda_context & ctx,
} }
} }
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits, const ggml_tensor * logits,
ggml_tensor * weights, ggml_tensor * weights,
ggml_tensor * ids, ggml_tensor * ids,
const bool with_norm, const ggml_tensor * clamp,
const bool delayed_softmax, const ggml_tensor * scale,
ggml_tensor * clamp) { const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args) {
GGML_ASSERT(logits->type == GGML_TYPE_F32); GGML_ASSERT(logits->type == GGML_TYPE_F32);
GGML_ASSERT(weights->type == GGML_TYPE_F32); GGML_ASSERT(weights->type == GGML_TYPE_F32);
GGML_ASSERT(ids->type == GGML_TYPE_I32); GGML_ASSERT(ids->type == GGML_TYPE_I32);
@ -245,107 +329,75 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const float * logits_d = (const float *) logits->data; const float * logits_d = (const float *) logits->data;
float * weights_d = (float *) weights->data; float * weights_d = (float *) weights->data;
int32_t * ids_d = (int32_t *) ids->data; int32_t * ids_d = (int32_t *) ids->data;
float * bias_d = bias ? (float *) bias->data : nullptr;
float scale_val = scale ? ggml_get_op_params_f32(scale, 0) : 1.0f;
GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts);
const int n_expert_used = weights->ne[1]; const int n_expert_used = weights->ne[1];
const bool with_norm = clamp != nullptr;
float clamp_val = -INFINITY; float clamp_val = -INFINITY;
if (with_norm) { if (clamp) {
if (clamp) { clamp_val = ggml_get_op_params_f32(clamp, 0);
clamp_val = ggml_get_op_params_f32(clamp, 0); }
}
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used, clamp_val); topk_moe_config config;
config.use_sigmoid = args.sigmoid;
config.with_norm = with_norm;
config.delayed_softmax = args.delayed_softmax;
if (bias) {
launch_topk_moe_cuda<true>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
scale_val, config);
} else { } else {
GGML_ASSERT(clamp == nullptr); launch_topk_moe_cuda<false>(ctx, logits_d, weights_d, ids_d, bias_d, n_rows, n_experts, n_expert_used, clamp_val,
if (delayed_softmax) { scale_val, config);
launch_topk_moe_cuda<false, true>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
} else {
launch_topk_moe_cuda<false, false>(ctx, logits_d, weights_d, ids_d, n_rows, n_experts, n_expert_used,
clamp_val);
}
} }
} }
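In the updated op, normalization is tied to the presence of the fused CLAMP node and the scale defaults to 1 when no SCALE node is fused. A small worked sketch of what the clamp does to the normalizer when it is present (values illustrative):

#include <algorithm>
#include <cstdio>

int main() {
    const float scale_val = 1.0f;   // default when no SCALE node is fused
    const float clamp_min = 1e-6f;  // hypothetical op_params[0] of the fused CLAMP

    const float sums[] = {0.25f, 0.0f};  // healthy sum vs. degenerate all-zero sum
    for (float wt_sum : sums) {
        const float inv = 1.0f / std::max(wt_sum, clamp_min);
        std::printf("wt_sum=%g -> normalizer=%g (scale=%g)\n", wt_sum, inv, scale_val);
    }
}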
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights, const ggml_tensor * weights,
const ggml_tensor * get_rows, const ggml_tensor * logits,
const ggml_tensor * argsort, const ggml_tensor * ids) {
const ggml_tensor * clamp, const int n_expert = ids->nb[1] / ids->nb[0];
int n_expert) { if (((n_expert & (n_expert - 1)) != 0 || n_expert > 512) && n_expert != 576) {
ggml_tensor * probs = get_rows->src[0];
if (probs->op != GGML_OP_RESHAPE) {
return false;
}
probs = probs->src[0];
ggml_tensor * selection_probs = argsort->src[0];
if (probs != selection_probs) {
return false; return false;
} }
float scale = 1.0f; if (!ggml_is_contiguous(weights) || !ggml_is_contiguous(logits)) {
float max_bias = 0.0f;
memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
if (!ggml_is_contiguous(softmax->src[0]) || !ggml_is_contiguous(weights)) {
return false; return false;
} }
if (scale != 1.0f || max_bias != 0.0f) { if (gating_op->op == GGML_OP_SOFT_MAX) {
return false; const ggml_tensor * softmax = gating_op;
} float scale = 1.0f;
float max_bias = 0.0f;
// don't fuse when masks or sinks are present memcpy(&scale, (const float *) softmax->op_params + 0, sizeof(float));
if (softmax->src[1] || softmax->src[2]) { memcpy(&max_bias, (const float *) softmax->op_params + 1, sizeof(float));
return false;
}
// n_expert must be a power of 2 if (!ggml_is_contiguous(softmax->src[0])) {
if ((n_expert & (n_expert - 1)) != 0 || n_expert > 512) {
return false;
}
if (clamp) {
if (clamp->op != GGML_OP_CLAMP) {
return false; return false;
} }
float max_val = ggml_get_op_params_f32(clamp, 1);
if (max_val != INFINITY) { if (scale != 1.0f || max_bias != 0.0f) {
return false;
}
// don't fuse when masks or sinks are present
if (softmax->src[1] || softmax->src[2]) {
return false;
}
} else if (gating_op->op == GGML_OP_UNARY) {
ggml_unary_op op = ggml_get_unary_op(gating_op);
if (op != GGML_UNARY_OP_SIGMOID) {
return false; return false;
} }
} }
return true; return true;
} }
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool norm, bool delayed_softmax) {
static std::initializer_list<enum ggml_op> norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SUM_ROWS, GGML_OP_CLAMP, GGML_OP_DIV,
GGML_OP_RESHAPE };
static std::initializer_list<enum ggml_op> no_norm_ops = { GGML_OP_SOFT_MAX, GGML_OP_RESHAPE, GGML_OP_ARGSORT,
GGML_OP_VIEW, GGML_OP_GET_ROWS };
static std::initializer_list<enum ggml_op> delayed_softmax_ops = { GGML_OP_ARGSORT, GGML_OP_VIEW,
GGML_OP_GET_ROWS, GGML_OP_RESHAPE,
GGML_OP_SOFT_MAX, GGML_OP_RESHAPE };
GGML_ASSERT(!norm || !delayed_softmax);
if (delayed_softmax) {
return delayed_softmax_ops;
}
if (norm) {
return norm_ops;
}
return no_norm_ops;
}

View File

@ -3,19 +3,25 @@
#include <initializer_list> #include <initializer_list>
void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, struct ggml_cuda_topk_moe_args {
const ggml_tensor * logits, bool sigmoid{};
ggml_tensor * weights, bool softmax{};
ggml_tensor * ids, bool delayed_softmax{};
const bool with_norm, bool prob_bias{};
const bool delayed_softmax = false, bool norm{};
ggml_tensor * weight_clamp = nullptr); bool scale{};
};
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * softmax, void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx,
const ggml_tensor * logits,
ggml_tensor * weights,
ggml_tensor * ids,
const ggml_tensor * clamp,
const ggml_tensor * scale,
const ggml_tensor * bias,
const ggml_cuda_topk_moe_args & args);
bool ggml_cuda_should_use_topk_moe(const ggml_tensor * gating_op,
const ggml_tensor * weights, const ggml_tensor * weights,
const ggml_tensor * get_rows, const ggml_tensor * logits,
const ggml_tensor * argsort, const ggml_tensor * ids);
const ggml_tensor * clamp,
int n_expert);
std::initializer_list<enum ggml_op> ggml_cuda_topk_moe_ops(bool with_norm, bool delayed_softmax = false);
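A sketch of how a caller of the new entry point might fill the args struct; the field meanings are inferred from the names and the CUDA-side usage above, so treat them as assumptions rather than the definitive contract.

// assumes the ggml_cuda_topk_moe_args declaration above is in scope
static ggml_cuda_topk_moe_args make_softmax_norm_args() {
    ggml_cuda_topk_moe_args args{};
    args.softmax         = true;   // gate with softmax ...
    args.sigmoid         = false;  // ... rather than sigmoid
    args.delayed_softmax = false;  // softmax before top-k, not after
    args.prob_bias       = false;  // no per-expert selection bias tensor
    args.norm            = true;   // normalize the selected weights
    args.scale           = false;  // no trailing SCALE node
    return args;
}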

View File

@ -1,7 +1,29 @@
file(TO_CMAKE_PATH "${HEXAGON_SDK_ROOT}" HEXAGON_SDK_ROOT)
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
if (NOT IS_DIRECTORY "${HEXAGON_SDK_ROOT}")
message(FATAL_ERROR "Make sure HEXAGON_SDK_ROOT point to the correct Hexagon SDK installation.")
endif()
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
message("Try to read HEXAGON_TOOLS_ROOT from hexagon_sdk.json")
file(READ "${HEXAGON_SDK_ROOT}/hexagon_sdk.json" HEXAGON_SDK_CONFIG_PATH)
string(JSON HEXAGON_TOOLS_PATH GET ${HEXAGON_SDK_CONFIG_PATH} "root" "tools" "info" 0 "path")
message("Found HEXAGON_TOOLS_PATH: ${HEXAGON_TOOLS_PATH}")
set(HEXAGON_TOOLS_ROOT "${HEXAGON_SDK_ROOT}/${HEXAGON_TOOLS_PATH}")
file(TO_CMAKE_PATH "${HEXAGON_TOOLS_ROOT}" HEXAGON_TOOLS_ROOT)
if (NOT IS_DIRECTORY "${HEXAGON_TOOLS_ROOT}")
message(FATAL_ERROR "Make sure HEXAGON_TOOLS_ROOT point to the correct Hexagon SDK installation.")
endif()
endif()
message(STATUS "hexagon: using ${HEXAGON_SDK_ROOT} and ${HEXAGON_TOOLS_ROOT} for building libggml-htp skels")
include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake) include(${HEXAGON_SDK_ROOT}/build/cmake/hexagon_fun.cmake)
include(ExternalProject) include(ExternalProject)
option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF) option(GGML_HEXAGON_HTP_DEBUG "ggml-hexagon: enable HTP debug output" OFF)
set(GGML_HEXAGON_HTP_CERT "$ENV{HEXAGON_HTP_CERT}" CACHE PATH "ggml-hexagon: enable HTP library signing using certificate")
set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)") set(GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE 128 CACHE STRING "ggml-hexagon: quantize group size (32, 64, or 128)")
add_library(htp_iface OBJECT add_library(htp_iface OBJECT
@ -25,56 +47,71 @@ else()
target_link_options(htp_iface PUBLIC -ldl) target_link_options(htp_iface PUBLIC -ldl)
endif() endif()
link_custom_library(htp_iface cdsprpc)
link_custom_library(htp_iface rpcmem)
set(TARGET_NAME ggml-hexagon) set(TARGET_NAME ggml-hexagon)
ggml_add_backend_library(${TARGET_NAME} ggml_add_backend_library(${TARGET_NAME}
ggml-hexagon.cpp htp-utils.c htp-utils.h ../../include/ggml-hexagon.h) ggml-hexagon.cpp
htp-drv.cpp
htp-drv.h
libdl.h
../../include/ggml-hexagon.h)
target_link_libraries(${TARGET_NAME} PRIVATE htp_iface) target_link_libraries(${TARGET_NAME} PRIVATE htp_iface)
target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR}) target_include_directories(${TARGET_NAME} PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/htp ${CMAKE_CURRENT_BINARY_DIR})
# Build HTP bits # Build HTP skels
set(HTP_CMAKE_ARGS set(HTP_SKELS)
-DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake function(build_htp_skel V)
-DCMAKE_BUILD_TYPE=Release ExternalProject_Add(htp-${V}
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR} SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
-DHEXAGON_SDK_ROOT=$ENV{HEXAGON_SDK_ROOT} BUILD_BYPRODUCTS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so
-DHEXAGON_TOOLS_ROOT=$ENV{HEXAGON_TOOLS_ROOT} CMAKE_ARGS
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG} -DCMAKE_BUILD_TYPE=Release
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}) -DCMAKE_TOOLCHAIN_FILE=${CMAKE_CURRENT_SOURCE_DIR}/htp/cmake-toolchain.cmake
-DCMAKE_INSTALL_LIBDIR=${CMAKE_CURRENT_BINARY_DIR}
-DHEXAGON_SDK_ROOT=${HEXAGON_SDK_ROOT}
-DHEXAGON_TOOLS_ROOT=${HEXAGON_TOOLS_ROOT}
-DHEXAGON_HTP_DEBUG=${GGML_HEXAGON_HTP_DEBUG}
-DGGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE=${GGML_HEXAGON_FP32_QUANTIZE_GROUP_SIZE}
-DDSP_VERSION=${V}
-DPREBUILT_LIB_DIR="toolv19_${V}")
list(APPEND HTP_SKELS ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-${V}.so)
set(HTP_SKELS ${HTP_SKELS} PARENT_SCOPE)
endfunction()
ExternalProject_Add(htp-v68 build_htp_skel(v68)
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON build_htp_skel(v69)
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v68 -DPREBUILT_LIB_DIR="toolv19_v68") build_htp_skel(v73)
build_htp_skel(v75)
ExternalProject_Add(htp-v69 build_htp_skel(v79)
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON build_htp_skel(v81)
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v69 -DPREBUILT_LIB_DIR="toolv19_v69")
ExternalProject_Add(htp-v73
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v73 -DPREBUILT_LIB_DIR="toolv19_v73")
ExternalProject_Add(htp-v75
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v75 -DPREBUILT_LIB_DIR="toolv19_v75")
ExternalProject_Add(htp-v79
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v79 -DPREBUILT_LIB_DIR="toolv19_v79")
ExternalProject_Add(htp-v81
SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/htp BUILD_ALWAYS ON
CMAKE_ARGS ${HTP_CMAKE_ARGS} -DDSP_VERSION=v81 -DPREBUILT_LIB_DIR="toolv19_v81")
# Install Hexagon skels required at runtime # Install Hexagon skels required at runtime
install(FILES install(FILES ${HTP_SKELS} TYPE LIB)
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v68.so
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v69.so if (CMAKE_SYSTEM_NAME MATCHES Windows AND GGML_HEXAGON_HTP_CERT)
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v73.so file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/arm64" WINSDK_BIN0_ARM64)
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v75.so file(TO_CMAKE_PATH "$ENV{WINDOWS_SDK_BIN}/x86" WINSDK_BIN0_X86)
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v79.so file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/arm64" WINSDK_BIN1_ARM64)
${CMAKE_CURRENT_BINARY_DIR}/libggml-htp-v81.so file(TO_CMAKE_PATH "$ENV{WindowsSdkVerBinPath}/x86" WINSDK_BIN1_X86)
TYPE LIB)
set(WINSDK_PATHS ${WINSDK_BIN0_ARM64} ${WINSDK_BIN0_X86} ${WINSDK_BIN1_ARM64} ${WINSDK_BIN1_X86})
find_program(INF2CAT NAMES inf2cat.exe PATHS ${WINSDK_PATHS} REQUIRED)
find_program(SIGNTOOL NAMES signtool.exe PATHS ${WINSDK_PATHS} REQUIRED)
message(STATUS "hexagon: using ${GGML_HEXAGON_HTP_CERT} to sign libggml-htp skels")
set(LIBGGML_HTP_CAT ${CMAKE_CURRENT_BINARY_DIR}/libggml-htp.cat)
add_custom_target(libggml-htp-cat
BYPRODUCTS ${LIBGGML_HTP_CAT}
DEPENDS libggml-htp.inf ${HTP_SKELS}
COMMAND ${CMAKE_COMMAND} -E copy ${CMAKE_CURRENT_SOURCE_DIR}/libggml-htp.inf ${CMAKE_CURRENT_BINARY_DIR}
COMMAND ${INF2CAT} /driver:${CMAKE_CURRENT_BINARY_DIR} /os:10_25H2_ARM64
COMMAND ${SIGNTOOL} sign /fd sha256 /f ${GGML_HEXAGON_HTP_CERT} ${LIBGGML_HTP_CAT}
COMMENT "generating and signing libggml-htp.cat file"
VERBATIM
)
add_dependencies(${TARGET_NAME} libggml-htp-cat)
install(FILES ${LIBGGML_HTP_CAT} TYPE LIB)
endif()

View File

@ -14,9 +14,6 @@
#ifdef _WIN32 #ifdef _WIN32
# include <sal.h> # include <sal.h>
# ifndef _WINDOWS
# define _WINDOWS
# endif
#else #else
# include <semaphore.h> # include <semaphore.h>
# include <unistd.h> # include <unistd.h>
@ -25,8 +22,6 @@
#pragma clang diagnostic ignored "-Wnested-anon-types" #pragma clang diagnostic ignored "-Wnested-anon-types"
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct" #pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#include "htp-utils.h"
#include <AEEStdErr.h> #include <AEEStdErr.h>
#include <dspqueue.h> #include <dspqueue.h>
#include <rpcmem.h> #include <rpcmem.h>
@ -40,6 +35,7 @@
#include "op-desc.h" #include "op-desc.h"
#include "htp-msg.h" #include "htp-msg.h"
#include "htp_iface.h" #include "htp_iface.h"
#include "htp-drv.h"
static size_t opt_ndev = 1; static size_t opt_ndev = 1;
static size_t opt_nhvx = 0; // use all static size_t opt_nhvx = 0; // use all
@ -150,9 +146,9 @@ void ggml_hexagon_session::enqueue(struct htp_general_req &req, struct dspqueue_
0, // flags - the framework will autoset this 0, // flags - the framework will autoset this
n_bufs, // number of buffers n_bufs, // number of buffers
bufs, // buffer references bufs, // buffer references
sizeof(req), sizeof(req), // Message length
(const uint8_t *) &req, // Message (const uint8_t *) &req, // Message
1000000 // Timeout DSPQUEUE_TIMEOUT // Timeout
); );
if (err != 0) { if (err != 0) {
@ -182,13 +178,13 @@ void ggml_hexagon_session::flush() {
// Read response packet from queue // Read response packet from queue
int err = dspqueue_read(q, &flags, int err = dspqueue_read(q, &flags,
HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references HTP_MAX_PACKET_BUFFERS, // Maximum number of buffer references
&n_bufs, // Number of buffer references &n_bufs, // Number of buffer references
bufs, // Buffer references bufs, // Buffer references
sizeof(rsp), // Max message length sizeof(rsp), // Max message length
&rsp_size, // Message length &rsp_size, // Message length
(uint8_t *) &rsp, (uint8_t *) &rsp, // Message
1000000); // Timeout DSPQUEUE_TIMEOUT); // Timeout
if (err == AEE_EEXPIRED) { if (err == AEE_EEXPIRED) {
// TODO: might need to bail out if the HTP is stuck on something // TODO: might need to bail out if the HTP is stuck on something
@ -269,13 +265,7 @@ struct ggml_backend_hexagon_buffer_context {
ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) { ggml_backend_hexagon_buffer_context(ggml_hexagon_session * sess, size_t size, bool repack) {
size += 4 * 1024; // extra page for padding size += 4 * 1024; // extra page for padding
if (rpcmem_alloc2) { this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
this->base = (uint8_t *) rpcmem_alloc2(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
} else {
GGML_LOG_INFO("ggml-hex: %s rpcmem_alloc2 not found, falling back to rpcmem_alloc\n", sess->name.c_str());
this->base = (uint8_t *) rpcmem_alloc(RPCMEM_HEAP_ID_SYSTEM, RPCMEM_DEFAULT_FLAGS | RPCMEM_HEAP_NOREG, size);
}
if (!this->base) { if (!this->base) {
GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size); GGML_LOG_ERROR("ggml-hex: %s failed to allocate buffer : size %zu\n", sess->name.c_str(), size);
throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)"); throw std::runtime_error("ggml-hex: rpcmem_alloc failed (see log for details)");
@ -2461,12 +2451,12 @@ static void ggml_backend_hexagon_free(ggml_backend_t backend) {
} }
static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) { static inline bool op_reuse_src1(const ggml_tensor * op1, const ggml_tensor * op0) {
return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type) && ggml_is_quantized(op1->src[1]->type)); return (op0 && op0->src[1] == op1->src[1] && ggml_is_quantized(op0->src[0]->type));
} }
static inline bool is_compute_op(ggml_tensor *node) static inline bool is_compute_op(ggml_tensor *node)
{ {
return !(ggml_op_is_empty(node->op) || ggml_is_empty(node)); return !ggml_op_is_empty(node->op) && !ggml_is_empty(node) && (node->flags & GGML_TENSOR_FLAG_COMPUTE);
} }
// scan the graph and figure out last compute op index // scan the graph and figure out last compute op index
@ -2488,7 +2478,7 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
const int last = last_compute_op(graph); const int last = last_compute_op(graph);
const struct ggml_tensor * prev_quant_op = nullptr; // prev executed op with quantizer const struct ggml_tensor * prev_op = nullptr; // prev executed op
for (int i = 0; i < graph->n_nodes; ++i) { for (int i = 0; i < graph->n_nodes; ++i) {
ggml_tensor * node = graph->nodes[i]; ggml_tensor * node = graph->nodes[i];
@ -2497,17 +2487,15 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
continue; continue;
} }
if ((node->flags & GGML_TENSOR_FLAG_COMPUTE) == 0) {
continue;
}
uint32_t flags = 0; uint32_t flags = 0;
// skip quantizer if src1 is reused // skip quantizer if src1 is reused
if (op_reuse_src1(node, prev_quant_op)) { if (op_reuse_src1(node, prev_op)) {
flags |= HTP_OPFLAGS_SKIP_QUANTIZE; flags |= HTP_OPFLAGS_SKIP_QUANTIZE;
} }
prev_op = node;
// ask for early notification for the last Op // ask for early notification for the last Op
if (i == last) { if (i == last) {
flags |= HTP_OPFLAGS_EARLY_WAKEUP; flags |= HTP_OPFLAGS_EARLY_WAKEUP;
@ -2520,7 +2508,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
} else { } else {
ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags); ggml_hexagon_dispatch_op<init_binary_req<false>>(sess, node, flags);
} }
prev_quant_op = node;
break; break;
case GGML_OP_MUL_MAT_ID: case GGML_OP_MUL_MAT_ID:
if (ggml_is_quantized(node->src[0]->type)) { if (ggml_is_quantized(node->src[0]->type)) {
@ -2528,7 +2515,6 @@ static ggml_status ggml_backend_hexagon_graph_compute(ggml_backend_t backend, gg
} else { } else {
ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags); ggml_hexagon_dispatch_op<init_binary_id_req<false>>(sess, node, flags);
} }
prev_quant_op = node;
break; break;
case GGML_OP_MUL: case GGML_OP_MUL:
case GGML_OP_ADD: case GGML_OP_ADD:
@ -2670,7 +2656,7 @@ static std::vector<int> ggml_hexagon_graph_optimize_reorder(const std::vector<no
} }
// how many nodes to look ahead for stackable nodes that can reuse VTCM // how many nodes to look ahead for stackable nodes that can reuse VTCM
constexpr int N_FORWARD = 8; constexpr int N_FORWARD = 16;
for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) { for (int i1 = i0 + 1; i1 < i0 + N_FORWARD && i1 < n; i1++) {
if (used[i1]) { if (used[i1]) {
@ -3056,10 +3042,12 @@ ggml_hexagon_registry::ggml_hexagon_registry(ggml_backend_reg_t reg) {
} }
} }
#if defined(__ANDROID__)
if (opt_arch < 75) { if (opt_arch < 75) {
opt_ndev = 1; opt_ndev = 1;
GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n"); GGML_LOG_WARN("ggml-hex: forcing ndev to 1 for SoCs archs lower than v75.\n");
} }
#endif
GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch); GGML_LOG_INFO("ggml-hex: Hexagon Arch version v%d\n", opt_arch);
@ -3156,6 +3144,8 @@ static void ggml_hexagon_init(ggml_backend_reg * reg) {
opt_arch = strtoul(str_arch, NULL, 0); opt_arch = strtoul(str_arch, NULL, 0);
} }
opt_hostbuf = str_hostbuf ? atoi(str_hostbuf) : 1;
reg->context = new ggml_hexagon_registry(reg); reg->context = new ggml_hexagon_registry(reg);
HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req), HEX_VERBOSE("ggml-hex: size-of-general-req %zu size-of-general-rsp %zu\n", sizeof(struct htp_general_req),
@ -3180,6 +3170,11 @@ ggml_backend_reg_t ggml_backend_hexagon_reg(void) {
static std::mutex mutex; static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex); std::lock_guard<std::mutex> lock(mutex);
if (!initialized) { if (!initialized) {
auto nErr = htpdrv_init();
if (nErr != AEE_SUCCESS) {
return NULL;
}
ggml_hexagon_init(&reg); ggml_hexagon_init(&reg);
} }

View File

@ -0,0 +1,418 @@
// Sample driver interface: dynamically loads libcdsprpc and resolves the FastRPC / rpcmem / dspqueue entry points
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wsign-compare"
#include <filesystem>
#include <set>
#include <sstream>
#include <string>
#ifdef _WIN32
# define WIN32_LEAN_AND_MEAN
# ifndef NOMINMAX
# define NOMINMAX
# endif
# include <windows.h>
# include <winevt.h>
#else
# include <dlfcn.h>
# include <unistd.h>
#endif
#include "ggml-impl.h"
#include "htp-drv.h"
#include "libdl.h"
#include <domain.h>
//
// Driver API types
//
typedef void * (*rpcmem_alloc_pfn_t)(int heapid, uint32_t flags, int size);
typedef void * (*rpcmem_alloc2_pfn_t)(int heapid, uint32_t flags, size_t size);
typedef void (*rpcmem_free_pfn_t)(void * po);
typedef int (*rpcmem_to_fd_pfn_t)(void * po);
typedef AEEResult (*dspqueue_create_pfn_t)(int domain,
uint32_t flags,
uint32_t req_queue_size,
uint32_t resp_queue_size,
dspqueue_callback_t packet_callback,
dspqueue_callback_t error_callback,
void * callback_context,
dspqueue_t * queue);
typedef AEEResult (*dspqueue_close_pfn_t)(dspqueue_t queue);
typedef AEEResult (*dspqueue_export_pfn_t)(dspqueue_t queue, uint64_t *queue_id);
typedef AEEResult (*dspqueue_write_pfn_t)(dspqueue_t queue, uint32_t flags,
uint32_t num_buffers,
struct dspqueue_buffer *buffers,
uint32_t message_length,
const uint8_t *message,
uint32_t timeout_us);
typedef AEEResult (*dspqueue_read_pfn_t)(dspqueue_t queue, uint32_t *flags,
uint32_t max_buffers, uint32_t *num_buffers,
struct dspqueue_buffer *buffers,
uint32_t max_message_length,
uint32_t *message_length, uint8_t *message,
uint32_t timeout_us);
typedef int (*fastrpc_mmap_pfn_t)(int domain, int fd, void *addr, int offset, size_t length, enum fastrpc_map_flags flags);
typedef int (*fastrpc_munmap_pfn_t)(int domain, int fd, void *addr, size_t length);
typedef int (*remote_handle64_open_pfn_t)(const char* name, remote_handle64 *ph);
typedef int (*remote_handle64_invoke_pfn_t)(remote_handle64 h, uint32_t dwScalars, remote_arg *pra);
typedef int (*remote_handle64_close_pfn_t)(remote_handle h);
typedef int (*remote_handle_control_pfn_t)(uint32_t req, void* data, uint32_t datalen);
typedef int (*remote_handle64_control_pfn_t)(remote_handle64 h, uint32_t req, void* data, uint32_t datalen);
typedef int (*remote_session_control_pfn_t)(uint32_t req, void *data, uint32_t datalen);
//
// Driver API pfns
//
rpcmem_alloc_pfn_t rpcmem_alloc_pfn = nullptr;
rpcmem_alloc2_pfn_t rpcmem_alloc2_pfn = nullptr;
rpcmem_free_pfn_t rpcmem_free_pfn = nullptr;
rpcmem_to_fd_pfn_t rpcmem_to_fd_pfn = nullptr;
fastrpc_mmap_pfn_t fastrpc_mmap_pfn = nullptr;
fastrpc_munmap_pfn_t fastrpc_munmap_pfn = nullptr;
dspqueue_create_pfn_t dspqueue_create_pfn = nullptr;
dspqueue_close_pfn_t dspqueue_close_pfn = nullptr;
dspqueue_export_pfn_t dspqueue_export_pfn = nullptr;
dspqueue_write_pfn_t dspqueue_write_pfn = nullptr;
dspqueue_read_pfn_t dspqueue_read_pfn = nullptr;
remote_handle64_open_pfn_t remote_handle64_open_pfn = nullptr;
remote_handle64_invoke_pfn_t remote_handle64_invoke_pfn = nullptr;
remote_handle64_close_pfn_t remote_handle64_close_pfn = nullptr;
remote_handle_control_pfn_t remote_handle_control_pfn = nullptr;
remote_handle64_control_pfn_t remote_handle64_control_pfn = nullptr;
remote_session_control_pfn_t remote_session_control_pfn = nullptr;
//
// Driver API
//
void * rpcmem_alloc(int heapid, uint32_t flags, int size) {
return rpcmem_alloc_pfn(heapid, flags, size);
}
void * rpcmem_alloc2(int heapid, uint32_t flags, size_t size) {
if (rpcmem_alloc2_pfn) {
return rpcmem_alloc2_pfn(heapid, flags, size);
} else {
GGML_LOG_INFO("ggml-hex: rpcmem_alloc2 not found, falling back to rpcmem_alloc\n");
return rpcmem_alloc_pfn(heapid, flags, size);
}
}
void rpcmem_free(void * po) {
return rpcmem_free_pfn(po);
}
int rpcmem_to_fd(void * po) {
return rpcmem_to_fd_pfn(po);
}
HTPDRV_API int fastrpc_mmap(int domain, int fd, void * addr, int offset, size_t length, enum fastrpc_map_flags flags) {
return fastrpc_mmap_pfn(domain, fd, addr, offset, length, flags);
}
HTPDRV_API int fastrpc_munmap(int domain, int fd, void * addr, size_t length) {
return fastrpc_munmap_pfn(domain, fd, addr, length);
}
AEEResult dspqueue_create(int domain,
uint32_t flags,
uint32_t req_queue_size,
uint32_t resp_queue_size,
dspqueue_callback_t packet_callback,
dspqueue_callback_t error_callback,
void * callback_context,
dspqueue_t * queue) {
return dspqueue_create_pfn(domain, flags, req_queue_size, resp_queue_size, packet_callback, error_callback,
callback_context, queue);
}
AEEResult dspqueue_close(dspqueue_t queue) {
return dspqueue_close_pfn(queue);
}
AEEResult dspqueue_export(dspqueue_t queue, uint64_t * queue_id) {
return dspqueue_export_pfn(queue, queue_id);
}
AEEResult dspqueue_write(dspqueue_t queue,
uint32_t flags,
uint32_t num_buffers,
struct dspqueue_buffer * buffers,
uint32_t message_length,
const uint8_t * message,
uint32_t timeout_us) {
return dspqueue_write_pfn(queue, flags, num_buffers, buffers, message_length, message, timeout_us);
}
AEEResult dspqueue_read(dspqueue_t queue,
uint32_t * flags,
uint32_t max_buffers,
uint32_t * num_buffers,
struct dspqueue_buffer * buffers,
uint32_t max_message_length,
uint32_t * message_length,
uint8_t * message,
uint32_t timeout_us) {
return dspqueue_read_pfn(queue, flags, max_buffers, num_buffers, buffers, max_message_length, message_length,
message, timeout_us);
}
HTPDRV_API int remote_handle64_open(const char * name, remote_handle64 * ph) {
return remote_handle64_open_pfn(name, ph);
}
HTPDRV_API int remote_handle64_invoke(remote_handle64 h, uint32_t dwScalars, remote_arg * pra) {
return remote_handle64_invoke_pfn(h, dwScalars, pra);
}
HTPDRV_API int remote_handle64_close(remote_handle64 h) {
return remote_handle64_close_pfn(h);
}
HTPDRV_API int remote_handle_control(uint32_t req, void * data, uint32_t datalen) {
return remote_handle_control_pfn(req, data, datalen);
}
HTPDRV_API int remote_handle64_control(remote_handle64 h, uint32_t req, void * data, uint32_t datalen) {
return remote_handle64_control_pfn(h, req, data, datalen);
}
HTPDRV_API int remote_session_control(uint32_t req, void * data, uint32_t datalen) {
return remote_session_control_pfn(req, data, datalen);
}
#ifdef _WIN32
static std::string wstr_to_str(std::wstring_view wstr) {
std::string result;
if (wstr.empty()) {
return result;
}
auto bytes_needed = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
wstr.data(), (int) wstr.size(),
nullptr, 0, nullptr, nullptr);
if (bytes_needed == 0) {
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
throw std::runtime_error("Invalid wstring input");
}
result.resize(bytes_needed, '\0');
int bytes_written = WideCharToMultiByte(CP_UTF8, WC_ERR_INVALID_CHARS,
wstr.data(), (int) wstr.size(),
result.data(), bytes_needed,
nullptr, nullptr);
if (bytes_written == 0) {
GGML_LOG_ERROR("ggml-hex: WideCharToMultiByte failed. Error %lu\n", GetLastError());
throw std::runtime_error("Wstring conversion failed");
}
return result;
}
static std::string get_driver_path() {
std::wstring serviceName = L"qcnspmcdm";
std::string result;
// Get a handle to the SCM database.
SC_HANDLE schSCManager = OpenSCManagerW(NULL, NULL, STANDARD_RIGHTS_READ);
if (nullptr == schSCManager) {
GGML_LOG_ERROR("ggml-hex: Failed to open SCManager. Error: %lu\n", GetLastError());
return result;
}
// Get a handle to the service.
SC_HANDLE schService = OpenServiceW(schSCManager, // SCM database
serviceName.c_str(), // name of service
SERVICE_QUERY_CONFIG); // need query config access
if (nullptr == schService) {
GGML_LOG_ERROR("ggml-hex: Failed to open qcnspmcdm service. Error: %lu\n", GetLastError());
CloseServiceHandle(schSCManager);
return result;
}
// Store the size of buffer used as an output.
DWORD bufferSize;
if (!QueryServiceConfigW(schService, NULL, 0, &bufferSize) &&
(GetLastError() != ERROR_INSUFFICIENT_BUFFER)) {
GGML_LOG_ERROR("ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return result;
}
// Get the configuration of the service.
LPQUERY_SERVICE_CONFIGW serviceConfig =
static_cast<LPQUERY_SERVICE_CONFIGW>(LocalAlloc(LMEM_FIXED, bufferSize));
if (!QueryServiceConfigW(schService, serviceConfig, bufferSize, &bufferSize)) {
fprintf(stderr, "ggml-hex: Failed to query service config. Error: %lu\n", GetLastError());
LocalFree(serviceConfig);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
return result;
}
// Read the driver file path and get its parent directory
std::wstring driverPath = std::wstring(serviceConfig->lpBinaryPathName);
driverPath = driverPath.substr(0, driverPath.find_last_of(L"\\"));
// Clean up resources
LocalFree(serviceConfig);
CloseServiceHandle(schService);
CloseServiceHandle(schSCManager);
// The driver path may contain an invalid path prefix, like:
// \SystemRoot\System32\DriverStore\FileRepository\qcadsprpc8280.inf_arm64_c2b9460c9a072f37
// "\SystemRoot" should be replaced with the actual root (e.g. C:\Windows)
const std::wstring systemRootPlaceholder = L"\\SystemRoot";
if (0 != driverPath.compare(0, systemRootPlaceholder.length(), systemRootPlaceholder)) {
GGML_LOG_ERROR("ggml-hex: String pattern not found in driver path.\n");
return result;
}
// Replace \SystemRoot with an absolute path from system ENV windir
const std::wstring systemRootEnv = L"windir";
// Query the number of wide characters this variable requires
DWORD numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), NULL, 0);
if (numWords == 0) {
GGML_LOG_ERROR("ggml-hex: Failed get systemRoot environment variable\n");
return result;
}
// Query the actual system root name from environment variable
std::vector<wchar_t> systemRoot(numWords + 1);
numWords = GetEnvironmentVariableW(systemRootEnv.c_str(), systemRoot.data(), numWords + 1);
if (numWords == 0) {
GGML_LOG_ERROR("ggml-hex: Failed to read windir environment variable\n");
return result;
}
driverPath.replace(0, systemRootPlaceholder.length(), std::wstring(systemRoot.data()));
return wstr_to_str(driverPath);
}
#endif
using dl_handle_ptr = std::unique_ptr<dl_handle, dl_handle_deleter>;
int htpdrv_init() {
static dl_handle_ptr lib_cdsp_rpc_handle = nullptr;
static bool initialized = false;
#ifdef _WIN32
std::string drv_path = get_driver_path() + "\\" + "libcdsprpc.dll";
#else
std::string drv_path = "libcdsprpc.so";
#endif
if (initialized) {
GGML_LOG_INFO("ggml-hex: Driver already loaded\n");
return AEE_SUCCESS;
}
GGML_LOG_INFO("ggml-hex: Loading driver %s\n", drv_path.c_str());
fs::path path{ drv_path.c_str() };
dl_handle_ptr handle { dl_load_library(path) };
if (!handle) {
GGML_LOG_ERROR("ggml-hex: failed to load %s: %s\n", path.u8string().c_str(), dl_error());
return AEE_EUNABLETOLOAD;
}
#define dlsym(drv, type, pfn, symbol, ignore) \
do { \
pfn = (type) dl_get_sym(drv, #symbol); \
if (!ignore && nullptr == pfn) { \
GGML_LOG_ERROR("ggml-hex: failed to dlsym %s\n", #symbol); \
return AEE_EUNABLETOLOAD; \
} \
} while (0)
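// Resolve each driver entry point from the loaded library; symbols marked as optional (ignore == true) may legitimately stay null.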
dlsym(handle.get(), rpcmem_alloc_pfn_t, rpcmem_alloc_pfn, rpcmem_alloc, false);
dlsym(handle.get(), rpcmem_alloc2_pfn_t, rpcmem_alloc2_pfn, rpcmem_alloc2, true);
dlsym(handle.get(), rpcmem_free_pfn_t, rpcmem_free_pfn, rpcmem_free, false);
dlsym(handle.get(), rpcmem_to_fd_pfn_t, rpcmem_to_fd_pfn, rpcmem_to_fd, false);
dlsym(handle.get(), fastrpc_mmap_pfn_t, fastrpc_mmap_pfn, fastrpc_mmap, false);
dlsym(handle.get(), fastrpc_munmap_pfn_t, fastrpc_munmap_pfn, fastrpc_munmap, false);
dlsym(handle.get(), dspqueue_create_pfn_t, dspqueue_create_pfn, dspqueue_create, false);
dlsym(handle.get(), dspqueue_close_pfn_t, dspqueue_close_pfn, dspqueue_close, false);
dlsym(handle.get(), dspqueue_export_pfn_t, dspqueue_export_pfn, dspqueue_export, false);
dlsym(handle.get(), dspqueue_write_pfn_t, dspqueue_write_pfn, dspqueue_write, false);
dlsym(handle.get(), dspqueue_read_pfn_t, dspqueue_read_pfn, dspqueue_read, false);
dlsym(handle.get(), remote_handle64_open_pfn_t, remote_handle64_open_pfn, remote_handle64_open, false);
dlsym(handle.get(), remote_handle64_invoke_pfn_t, remote_handle64_invoke_pfn, remote_handle64_invoke, false);
dlsym(handle.get(), remote_handle_control_pfn_t, remote_handle_control_pfn, remote_handle_control, false);
dlsym(handle.get(), remote_handle64_control_pfn_t, remote_handle64_control_pfn, remote_handle64_control, false);
dlsym(handle.get(), remote_session_control_pfn_t, remote_session_control_pfn, remote_session_control, false);
dlsym(handle.get(), remote_handle64_close_pfn_t, remote_handle64_close_pfn, remote_handle64_close, false);
lib_cdsp_rpc_handle = std::move(handle);
initialized = true;
return AEE_SUCCESS;
}
domain * get_domain(int domain_id) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return &supported_domains[i];
}
}
return NULL;
}
int get_hex_arch_ver(int domain, int * arch) {
if (!remote_handle_control_pfn) {
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
struct remote_dsp_capability arch_ver;
arch_ver.domain = (uint32_t) domain;
arch_ver.attribute_ID = ARCH_VER;
arch_ver.capability = (uint32_t) 0;
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
if (err != AEE_SUCCESS) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
return err;
}
switch (arch_ver.capability & 0xff) {
case 0x68:
*arch = 68;
return 0;
case 0x69:
*arch = 69;
return 0;
case 0x73:
*arch = 73;
return 0;
case 0x75:
*arch = 75;
return 0;
case 0x79:
*arch = 79;
return 0;
case 0x81:
*arch = 81;
return 0;
}
return -1;
}

View File

@ -0,0 +1,121 @@
#pragma once
#ifdef __cplusplus
extern "C" {
#endif
#ifdef _WIN32
# pragma clang diagnostic ignored "-Wignored-attributes"
#endif
#include <AEEStdErr.h>
#include <rpcmem.h>
#include <remote.h>
#include <dspqueue.h>
#if defined(_WIN32) && !defined(__MINGW32__)
# ifdef GGML_BACKEND_BUILD
# define HTPDRV_API __declspec(dllexport) extern
# else
# define HTPDRV_API __declspec(dllimport) extern
# endif
#else
# define HTPDRV_API __attribute__ ((visibility ("default"))) extern
#endif
/* Offset to differentiate HLOS and Hexagon error codes.
Stores the value of AEE_EOFFSET for Hexagon. */
#ifndef DSP_OFFSET
# define DSP_OFFSET 0x80000400
#endif
/* Errno for connection reset by peer. */
#ifndef ECONNRESET
# ifdef __hexagon__
# define ECONNRESET 104
# endif
#endif
/* Abstraction of different OS specific sleep APIs.
SLEEP accepts input in seconds. */
#ifndef SLEEP
# ifdef __hexagon__
# define SLEEP(x) \
{ /* Do nothing for simulator. */ \
}
# else
# ifdef _WIN32
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
# else
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
# endif
# endif
#endif
/* Include windows specific header files. */
#ifdef _WIN32
# include <windows.h>
# include <sysinfoapi.h>
# define _CRT_SECURE_NO_WARNINGS 1
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
#endif
/* Includes and defines for all HLOS except windows */
#if !defined(__hexagon__) && !defined(_WIN32)
# include "unistd.h"
# include <sys/time.h>
#endif
/* Includes and defines for Hexagon and all HLOS except Windows. */
#if !defined(_WIN32)
/* Weak reference to remote symbol for compilation. */
# pragma weak remote_session_control
# pragma weak remote_handle_control
# pragma weak remote_handle64_control
# pragma weak fastrpc_mmap
# pragma weak fastrpc_munmap
# pragma weak rpcmem_alloc2
#endif
#if !defined(_WIN32)
# pragma weak remote_system_request
#endif
#ifdef _WIN32
# define DSPQUEUE_TIMEOUT DSPQUEUE_TIMEOUT_NONE
#else
# define DSPQUEUE_TIMEOUT 1000000
#endif
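/* DSPQUEUE_TIMEOUT is the timeout (in microseconds) passed to dspqueue_read/dspqueue_write;
   the Windows build uses DSPQUEUE_TIMEOUT_NONE, other platforms wait up to 1 second. */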
/**
* htpdrv_init API: driver interface entry point
*
* @return AEE_SUCCESS on success, otherwise an AEE error code as defined in the Hexagon SDK.
*/
HTPDRV_API int htpdrv_init(void);
/**
* get_domain API: get domain struct from domain value.
*
* @param[in] domain_id value of a domain
* @return Returns domain struct of the domain if it is supported or else
* returns NULL.
*
*/
HTPDRV_API domain * get_domain(int domain_id);
/**
* get_hex_arch_ver API: query the Hexagon processor architecture version information
*
* @param[in] domain value of a domain
* @param[out] arch Hexagon arch version (68, 69, 73, 75, ...)
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
HTPDRV_API int get_hex_arch_ver(int domain, int * arch);
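/*
 * Illustrative usage sketch (not part of this header; CDSP_DOMAIN_ID is assumed to come
 * from the Hexagon SDK domain headers):
 *
 *   int arch = 0;
 *   if (htpdrv_init() == AEE_SUCCESS && get_hex_arch_ver(CDSP_DOMAIN_ID, &arch) == 0) {
 *       // arch is one of 68, 69, 73, 75, 79, 81
 *   }
 */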
#ifdef __cplusplus
}
#endif

View File

@ -1,454 +0,0 @@
#pragma clang diagnostic ignored "-Wgnu-anonymous-struct"
#pragma clang diagnostic ignored "-Wmissing-prototypes"
#pragma clang diagnostic ignored "-Wsign-compare"
#define GGML_COMMON_IMPL_C
#include "ggml-backend-impl.h"
#include "ggml-common.h"
#include "ggml-hexagon.h"
#include "ggml-impl.h"
#include "htp-utils.h"
#include <domain.h>
#include <remote.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
domain * get_domain(int domain_id) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return &supported_domains[i];
}
}
return NULL;
}
bool is_valid_domain_id(int domain_id, int compute_only) {
int i = 0;
int size = sizeof(supported_domains) / sizeof(domain);
if (compute_only) {
return is_CDSP(domain_id);
}
for (i = 0; i < size; i++) {
if (supported_domains[i].id == domain_id) {
return true;
}
}
return false;
}
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info) {
int nErr = AEE_SUCCESS;
int ss_info = 0;
if (domain_type != NULL) {
if (strcmp(domain_type, "LPASS") == 0) {
ss_info = FASTRPC_LPASS;
} else if (strcmp(domain_type, "HPASS") == 0) {
ss_info = FASTRPC_HPASS;
} else {
ss_info = FASTRPC_NSP;
}
}
system_req_payload req = { 0 };
req.id = FASTRPC_GET_DOMAINS;
req.sys.domains = NULL;
fastrpc_domain * domain = NULL;
if (ss_info != 0) {
req.sys.flags = DOMAINS_LIST_FLAGS_SET_TYPE(req.sys.flags, ss_info);
} else {
req.sys.flags = 0;
}
#ifdef _WIN32
nErr = AEE_EUNSUPPORTED;
goto bail;
#endif
if (remote_system_request) {
nErr = remote_system_request(&req);
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
goto bail;
}
// Allocate memory for domain-info array
req.sys.max_domains = req.sys.num_domains;
if ((req.sys.domains = calloc(req.sys.num_domains, sizeof(fastrpc_domain))) == NULL) {
nErr = AEE_ENOMEMORY;
GGML_LOG_ERROR("Unable to allocate memory for req.sys.domains");
goto bail;
}
nErr = remote_system_request(&req);
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("Failure in remote_system_request call: %d.\n", nErr);
goto bail;
}
for (int i = 0; i < req.sys.num_domains; i++) {
// Verify that only requested type domains were returned
domain = &req.sys.domains[i];
if (domain->type != ss_info && domain_type != NULL) {
nErr = -1;
GGML_LOG_ERROR("Incorrect data received from remote_system_request.\n");
goto bail;
}
}
*domains_info = req.sys.domains;
*num_domains = req.sys.num_domains;
} else {
nErr = AEE_EUNSUPPORTED;
goto bail;
}
bail:
if (nErr && !req.sys.domains) {
free(req.sys.domains);
}
return nErr;
}
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id) {
int err = 0;
remote_rpc_effective_domain_id_t sess = { 0 };
sess.domain_name = domain_name;
sess.domain_name_len = strlen(domain_name);
sess.session_id = session_id;
err = remote_session_control(FASTRPC_GET_EFFECTIVE_DOMAIN_ID, &sess, sizeof(sess));
if (err) {
GGML_LOG_ERROR("Error 0x%x: failed to get effective domain id for %s, session id %d\n", err, sess.domain_name,
session_id);
return err;
}
*effec_domain_id = sess.effective_domain_id;
return err;
}
int get_dsp_support(int * domain) {
int nErr = AEE_SUCCESS;
*domain = CDSP_DOMAIN_ID; // DSP domain default value is CDSP_DOMAIN_ID
if (remote_handle_control) {
struct remote_dsp_capability dsp_capability_domain = { CDSP_DOMAIN_ID, DOMAIN_SUPPORT, 0 };
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
goto bail;
}
if (dsp_capability_domain.capability == 0) {
dsp_capability_domain.domain = ADSP_DOMAIN_ID; // Check for ADSP support.
dsp_capability_domain.attribute_ID = DOMAIN_SUPPORT;
dsp_capability_domain.capability = 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain,
sizeof(struct remote_dsp_capability));
if (dsp_capability_domain.capability) {
*domain = ADSP_DOMAIN_ID; // For targets like Agatti (not having cDSP), domain is ADSP_DOMAIN_ID
}
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nget_dsp_support failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (attr == VTCM_PAGE || attr == VTCM_COUNT) {
} else {
nErr = AEE_EBADPARM;
GGML_LOG_ERROR("Unsupported attr. Only VTCM_PAGE and VTCM_COUNT supported\n");
goto bail;
}
if (remote_handle_control) {
if (domain == ADSP_DOMAIN_ID || domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for VTCM information
* Since the ADSP does not have a dedicated VTCM, we expect the output to be 0
*/
struct remote_dsp_capability dsp_capability_vtcm_dsp;
dsp_capability_vtcm_dsp.domain = (uint32_t) domain;
dsp_capability_vtcm_dsp.attribute_ID = attr;
dsp_capability_vtcm_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_vtcm_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_vtcm_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_vtcm_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("Unsupported domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
bool is_unsignedpd_supported(int domain_id) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
struct remote_dsp_capability dsp_capability_domain = { domain_id, UNSIGNED_PD_SUPPORT, 0 };
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_domain, sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device. Falling back to signed pd.\n");
return false;
}
if (nErr) {
GGML_LOG_ERROR("\nERROR 0x%x: FastRPC Capability API failed. Falling back to signed pd.", nErr);
return false;
}
if (dsp_capability_domain.capability == 1) {
return true;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device. Falling back to signed pd.\n");
return false;
}
return false;
}
bool get_unsignedpd_support(void) {
return is_unsignedpd_supported(CDSP_DOMAIN_ID);
}
bool is_async_fastrpc_supported(int domain) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for ASYNC_FASTRPC_SUPPORT information
* Async fastrpc is supported only on CDSP
*/
struct remote_dsp_capability dsp_capability_async_support;
dsp_capability_async_support.domain = (uint32_t) domain;
dsp_capability_async_support.attribute_ID = ASYNC_FASTRPC_SUPPORT;
dsp_capability_async_support.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_async_support,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (dsp_capability_async_support.capability == 1) {
return true;
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nis_async_fastrpc_supported failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("Async fastrpc is not supported on domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return false;
}
bool is_status_notification_supported(int domain) {
int nErr = AEE_SUCCESS;
if (remote_handle_control) {
/*
* Query the DSP for STATUS_NOTIFICATION_SUPPORT information
* DSP User PD status notification Support
*/
struct remote_dsp_capability dsp_capability_status_notification_support;
dsp_capability_status_notification_support.domain = (uint32_t) domain;
dsp_capability_status_notification_support.attribute_ID = STATUS_NOTIFICATION_SUPPORT;
dsp_capability_status_notification_support.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_status_notification_support,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (dsp_capability_status_notification_support.capability == 1) {
return true;
}
if (nErr != AEE_SUCCESS) {
GGML_LOG_ERROR("\nis_status_notification_supported failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return false;
}
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (attr != HMX_SUPPORT_SPATIAL && attr != HMX_SUPPORT_DEPTH) {
nErr = AEE_EBADPARM;
GGML_LOG_ERROR("Unsupported attr. Only HMX_SUPPORT_SPATIAL and HMX_SUPPORT_DEPTH supported\n");
goto bail;
}
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for HMX SUPPORT information
* HMX is supported on CDSP only
*/
struct remote_dsp_capability dsp_capability_hmx_dsp;
dsp_capability_hmx_dsp.domain = (uint32_t) domain;
dsp_capability_hmx_dsp.attribute_ID = attr;
dsp_capability_hmx_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hmx_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_hmx_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_hmx_support_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("HMX support is not there for domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}
int get_hex_arch_ver(int domain, int * arch) {
if (!remote_handle_control) {
GGML_LOG_ERROR("ggml-hex: remote_handle_control is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
struct remote_dsp_capability arch_ver;
arch_ver.domain = (uint32_t) domain;
arch_ver.attribute_ID = ARCH_VER;
arch_ver.capability = (uint32_t) 0;
int err = remote_handle_control(DSPRPC_GET_DSP_INFO, &arch_ver, sizeof(arch_ver));
if ((err & 0xff) == (AEE_EUNSUPPORTEDAPI & 0xff)) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability API is not supported on this device\n");
return AEE_EUNSUPPORTEDAPI;
}
if (err != AEE_SUCCESS) {
GGML_LOG_ERROR("ggml-hex: FastRPC capability query failed (err %d)\n", err);
return err;
}
switch (arch_ver.capability & 0xff) {
case 0x68:
*arch = 68;
return 0;
case 0x69:
*arch = 69;
return 0;
case 0x73:
*arch = 73;
return 0;
case 0x75:
*arch = 75;
return 0;
case 0x79:
*arch = 79;
return 0;
case 0x81:
*arch = 81;
return 0;
}
return -1;
}
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr) {
int nErr = AEE_SUCCESS;
*capability = 0;
if (remote_handle_control) {
if (domain == CDSP_DOMAIN_ID) {
/*
* Query the DSP for HVX SUPPORT information
* HVX is supported on CDSP only
*/
struct remote_dsp_capability dsp_capability_hvx_dsp;
dsp_capability_hvx_dsp.domain = (uint32_t) domain;
dsp_capability_hvx_dsp.attribute_ID = attr;
dsp_capability_hvx_dsp.capability = (uint32_t) 0;
nErr = remote_handle_control(DSPRPC_GET_DSP_INFO, &dsp_capability_hvx_dsp,
sizeof(struct remote_dsp_capability));
if ((nErr & 0xFF) == (AEE_EUNSUPPORTEDAPI & 0xFF)) {
GGML_LOG_ERROR("\nFastRPC Capability API is not supported on this device\n");
GGML_LOG_ERROR("Running the usecase without checking the capability\n");
nErr = AEE_SUCCESS;
goto bail;
} else if (nErr == AEE_SUCCESS) {
*capability = dsp_capability_hvx_dsp.capability;
} else {
GGML_LOG_ERROR("\nget_hvx_support_info failed with Error 0x%x\n", nErr);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTED;
GGML_LOG_ERROR("HVX support is not available on domain %d\n", domain);
goto bail;
}
} else {
nErr = AEE_EUNSUPPORTEDAPI;
GGML_LOG_ERROR("remote_dsp_capability interface is not supported on this device\n");
}
bail:
return nErr;
}

View File

@ -1,221 +0,0 @@
#ifndef HTP_UTILS_H
#define HTP_UTILS_H
#ifdef __cplusplus
extern "C" {
#endif
#include <AEEStdErr.h>
#include <inttypes.h>
#include <remote.h>
#include <rpcmem.h>
#include <stdbool.h>
/* Offset to differentiate HLOS and Hexagon error codes.
Stores the value of AEE_EOFFSET for Hexagon. */
#ifndef DSP_OFFSET
# define DSP_OFFSET 0x80000400
#endif
/* Errno for connection reset by peer. */
#ifndef ECONNRESET
# ifdef __hexagon__
# define ECONNRESET 104
# endif
#endif
/* Abstraction of different OS specific sleep APIs.
SLEEP accepts input in seconds. */
#ifndef SLEEP
# ifdef __hexagon__
# define SLEEP(x) \
{ /* Do nothing for simulator. */ \
}
# else
# ifdef _WINDOWS
# define SLEEP(x) Sleep(1000 * x) /* Sleep accepts input in milliseconds. */
# else
# define SLEEP(x) sleep(x) /* sleep accepts input in seconds. */
# endif
# endif
#endif
/* Include windows specific header files. */
#ifdef _WINDOWS
# include <sysinfoapi.h>
# include <windows.h>
# define _CRT_SECURE_NO_WARNINGS 1
# define _WINSOCK_DEPRECATED_NO_WARNINGS 1
/* Including this file for custom implementation of getopt function. */
# include "getopt_custom.h"
#endif
/* Includes and defines for all HLOS except windows */
#if !defined(__hexagon__) && !defined(_WINDOWS)
# include "unistd.h"
# include <sys/time.h>
#endif
/* Includes and defines for Hexagon and all HLOS except Windows. */
#if !defined(_WINDOWS)
/* Weak reference to remote symbol for compilation. */
# pragma weak remote_session_control
# pragma weak remote_handle_control
# pragma weak remote_handle64_control
# pragma weak fastrpc_mmap
# pragma weak fastrpc_munmap
# pragma weak rpcmem_alloc2
#endif
#if !defined(_WINDOWS)
# pragma weak remote_system_request
#endif
/**
* Wrapper for FastRPC Capability API: query DSP support.
*
* @param[out] domain pointer to supported domain.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*/
int get_dsp_support(int * domain);
/**
* Wrapper for FastRPC Capability API: query VTCM information.
*
* @param[in] domain value of domain in the queried.
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to the queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*/
int get_vtcm_info(int domain, uint32_t * capability, uint32_t attr);
/**
* Wrapper for FastRPC Capability API: query unsigned pd support on CDSP domain.
*
* @return true if unsigned pd is supported.
* false if unsigned pd is not supported, capability query failed.
*/
bool get_unsignedpd_support(void);
/**
* Wrapper for FastRPC Capability API: query unsigned pd support.
*
* @param[in] domain value of domain in the queried.
* @return true if unsigned pd is supported.
* false if unsigned pd is not supported, capability query failed.
*/
bool is_unsignedpd_supported(int domain_id);
/**
* is_valid_domain_id API: query a domain id is valid.
*
* @param[in] domain value of domain in the queried.
* @param[in] compute_only value of domain is only compared with CDSP domains supported by the target when enabled.
* @return true if value of domain is valid.
* false if value of domain is not valid.
*/
bool is_valid_domain_id(int domain_id, int compute_only);
/**
* get_domain API: get domain struct from domain value.
*
* @param[in] domain value of a domain
* @return Returns domain struct of the domain if it is supported or else
* returns NULL.
*
*/
domain * get_domain(int domain_id);
/**
* get_domains_info API: get information for all the domains available on the device
*
* @param[in] domain_type pointer to domain type
* @param[in] num_domains pointer to number of domains
* @param[in] domains_info pointer to save discovered domains information.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
* It is user's responsibility to free the memory used to store the domains info whose address is present in domains_info before closing the application.
*
*/
int get_domains_info(char * domain_type, int * num_domains, fastrpc_domain ** domains_info);
/**
* get_effective_domain_id API: get effective domain id for given session id
*
* @param[in] domain_name pointer to domain name
* @param[in] session_id
* @param[in] effec_domain_id pointer to save obtained effective domain id.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_effective_domain_id(char * domain_name, int session_id, int * effec_domain_id);
/**
* is_async_fastrpc_supported API: query a domain id has async fastrpc supported or not
*
* @param[in] domain_id value of a domain
* @return Returns true or false stating support of Async FastRPC
*
*/
bool is_async_fastrpc_supported(int domain_id);
/**
* is_status_notification_supported API: query the DSP for STATUS_NOTIFICATION_SUPPORT information
*
* @param[in] domain_id value of a domain
* @return Returns true or false stating status notification support information
*
*/
bool is_status_notification_supported(int domain_id);
/**
* get_hmx_support_info API: query the DSP for HMX SUPPORT information
*
* @param[in] domain_id value of a domain
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to the queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hmx_support_info(int domain, uint32_t * capability, uint32_t attr);
/**
* get_hex_arch_ver API: query the Hexagon processor architecture version information
*
* @param[in] domain_id value of a domain
* @param[out] Arch version (73, 75, ...)
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hex_arch_ver(int domain, int * arch);
/**
* get_hvx_support_info API: query the DSP for HVX SUPPORT information
*
* @param[in] domain_id value of a domain
* @param[out] capability capability value of the attribute queried.
* @param[in] attr value of the attribute to the queried.
* @return 0 if query is successful.
* non-zero if error, return value points to the error.
*
*/
int get_hvx_support_info(int domain, uint32_t * capability, uint32_t attr);
#ifdef __cplusplus
}
#endif
#endif //DSP_CAPABILITIES_UTILS_H

View File

@ -17,6 +17,12 @@
#include "htp-msg.h" #include "htp-msg.h"
#include "htp-ops.h" #include "htp-ops.h"
static inline HVX_Vector hvx_load_f32_to_f16(const HVX_Vector * restrict src, const HVX_Vector zero) {
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(src[0], zero); // 32 elements
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(src[1], zero); // 32 elements
return Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
}
// Dot product of FP32 and FP16 vectors, accumulating to float // Dot product of FP32 and FP16 vectors, accumulating to float
static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) { static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict y, const void * restrict x, unsigned int n, float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32 const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
@ -33,23 +39,19 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
#pragma unroll(4) #pragma unroll(4)
for (i = 0; i < nvec; i++) { for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16 // Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16) // Load x (fp16)
HVX_Vector x_hf = vx[i]; HVX_Vector x_hf = vx[i];
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
} }
if (nloe) { if (nloe) {
// Load y (fp32) and convert into fp16 // Load y (fp32) and convert into fp16
HVX_Vector y0_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+0], zero); // 32 elements HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
HVX_Vector y1_qf = Q6_Vqf32_vsub_VsfVsf(vy[i*2+1], zero); // 32 elements
HVX_Vector y_hf = Q6_Vh_vdeal_Vh(Q6_Vhf_equals_Wqf32(Q6_W_vcombine_VV(y1_qf, y0_qf)));
// Load x (fp16) // Load x (fp16)
HVX_Vector x_hf = vx[i]; HVX_Vector x_hf = vx[i];
@ -62,13 +64,72 @@ static inline void hvx_dot_f32_f16_aa(float * restrict r, const void * restrict
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
} }
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
}
hvx_vec_store_u(r, 4, rsum); // Dot product of one FP32 vector with two FP16 rows, accumulating to two floats
static inline void hvx_dot_f32_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp32
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(2)
for (i = 0; i < nvec; i++) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
// Load y (fp32) and convert into fp16
HVX_Vector y_hf = hvx_load_f32_to_f16(&vy[i*2], zero);
// Load x (fp16)
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
// Zero-out unused elements
// Note that we need to clear both x and y because they may contain NaNs
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
x0_hf = Q6_V_vand_QV(bmask, x0_hf);
x1_hf = Q6_V_vand_QV(bmask, x1_hf);
y_hf = Q6_V_vand_QV(bmask, y_hf);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
} }
// Dot product of two F16 vectors, accumulating to float // Dot product of two F16 vectors, accumulating to float
@ -91,7 +152,7 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
} }
if (nloe) { if (nloe) {
@ -103,12 +164,62 @@ static inline void hvx_dot_f16_f16_aa(float * restrict r, const void * restrict
HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf); HVX_VectorPair xy_qf = Q6_Wqf32_vmpy_VhfVhf(x_hf, y_hf);
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)), rsum));
} }
rsum = Q6_Vqf32_vmpy_VsfVsf(Q6_Vsf_equals_Vqf32(rsum), hvx_vec_splat_f32(s)); rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32(rsum));
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); hvx_vec_store_u(r, 4, Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(r, 4, rsum); }
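// Dot product of one FP16 vector with two FP16 rows (x0, x1), producing two float results at once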
static inline void hvx_dot_f16_f16_aa_rx2(float * restrict r,
const void * restrict y,
const void * restrict x0,
const void * restrict x1,
unsigned int n,
float s) {
const HVX_Vector * restrict vx0 = (const HVX_Vector * restrict) x0; // fp16
const HVX_Vector * restrict vx1 = (const HVX_Vector * restrict) x1; // fp16
const HVX_Vector * restrict vy = (const HVX_Vector * restrict) y; // fp16
uint32_t nvec = n / VLEN_FP16; // num full fp16 hvx vectors
uint32_t nloe = n % VLEN_FP16; // leftover elements
const HVX_Vector zero = Q6_V_vsplat_R(0);
HVX_Vector rsum0 = Q6_V_vsplat_R(0);
HVX_Vector rsum1 = Q6_V_vsplat_R(0);
uint32_t i = 0;
#pragma unroll(4)
for (i = 0; i < nvec; i++) {
HVX_Vector y_hf = vy[i];
HVX_Vector x0_hf = vx0[i];
HVX_Vector x1_hf = vx1[i];
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
if (nloe) {
HVX_Vector y_hf = vy[i];
// Load x (fp16) and zero-out unused elements
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe * 2);
HVX_Vector x0_hf = Q6_V_vand_QV(bmask, vx0[i]);
HVX_Vector x1_hf = Q6_V_vand_QV(bmask, vx1[i]);
HVX_VectorPair xy0_qf = Q6_Wqf32_vmpy_VhfVhf(x0_hf, y_hf);
HVX_VectorPair xy1_qf = Q6_Wqf32_vmpy_VhfVhf(x1_hf, y_hf);
rsum0 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy0_qf), Q6_V_hi_W(xy0_qf)), rsum0));
rsum1 = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)), rsum1));
}
HVX_Vector rsum = Q6_Vqf32_vmpy_VsfVsf(hvx_vec_splat_f32(s), hvx_vec_reduce_sum_f32x2(rsum0, rsum1));
hvx_vec_store_u(r, 8, Q6_Vsf_equals_Vqf32(rsum));
} }
// MAD: y (F32) += x (F16) * s (float) // MAD: y (F32) += x (F16) * s (float)
@ -317,20 +428,22 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
// Inner loop processing the block from VTCM // Inner loop processing the block from VTCM
uint32_t ic = 0; uint32_t ic = 0;
const bool is_q_fp32 = (q->type == HTP_TYPE_F32);
// Process in blocks of 32 (VLEN_FP32) // Process in blocks of 32 (VLEN_FP32)
static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 == 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage"); static_assert(FLASH_ATTN_BLOCK_SIZE / VLEN_FP32 <= 4, "FLASH_ATTN_BLOCK_SIZE changed, fix HVX_Vector_x4 usage");
HVX_Vector_x4 scores_x4; HVX_Vector_x4 scores_x4;
HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY); HVX_Vector v_max = hvx_vec_splat_f32(-INFINITY);
for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) { for (uint32_t iv = 0; ic + VLEN_FP32 <= current_block_size; ic += VLEN_FP32, ++iv) {
// 1. Compute scores // 1. Compute scores
float __attribute__((aligned(VLEN))) scores_arr[FLASH_ATTN_BLOCK_SIZE]; float __attribute__((aligned(VLEN))) scores_arr[VLEN_FP32];
for (int j = 0; j < VLEN_FP32; ++j) { for (int j = 0; j < VLEN_FP32; j += 2) {
const uint32_t cur_ic = ic + j; const uint32_t cur_ic = ic + j;
const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded; const uint8_t * k_ptr = k_base + cur_ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) { if (is_q_fp32) {
hvx_dot_f32_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); hvx_dot_f32_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} else { } else {
hvx_dot_f16_f16_aa(&scores_arr[j], q_ptr_vtcm, k_ptr, DK, scale); hvx_dot_f16_f16_aa_rx2(&scores_arr[j], q_ptr_vtcm, k_ptr, k_ptr + size_k_row_padded, DK, scale);
} }
} }
@ -403,7 +516,7 @@ static void flash_attn_ext_f16_thread(struct htp_ops_context * octx, int ith, in
float s_val; float s_val;
const uint8_t * k_ptr = k_base + ic * size_k_row_padded; const uint8_t * k_ptr = k_base + ic * size_k_row_padded;
if (q->type == HTP_TYPE_F32) { if (is_q_fp32) {
hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); hvx_dot_f32_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);
} else { } else {
hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale); hvx_dot_f16_f16_aa(&s_val, q_ptr_vtcm, k_ptr, DK, scale);

View File

@ -28,19 +28,16 @@ static void hvx_vec_dump_f16(char * pref, HVX_Vector v) {
} }
static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) { static void hvx_vec_dump_f32_n(char * pref, HVX_Vector v, uint32_t n) {
union { HVX_VectorAlias u = { .v = v };
HVX_Vector v;
float d[32];
} u = { .v = v };
const uint32_t n0 = n / 16; const uint32_t n0 = n / 16;
const uint32_t n1 = n % 16; const uint32_t n1 = n % 16;
int i = 0; int i = 0;
for (; i < n0; i++) { for (; i < n0; i++) {
hex_dump_f32_line(pref, u.d + (16 * i), 16); hex_dump_f32_line(pref, u.fp32 + (16 * i), 16);
} }
if (n1) { if (n1) {
hex_dump_f32_line(pref, u.d + (16 * i), n1); hex_dump_f32_line(pref, u.fp32 + (16 * i), n1);
} }
} }

View File

@ -44,6 +44,45 @@ static inline HVX_Vector hvx_vec_reduce_sum_qf32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_qf32(in, 32); return hvx_vec_reduce_sum_n_qf32(in, 32);
} }
#if __HVX_ARCH__ > 75
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
HVX_Vector sum_sf = Q6_Vsf_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 2));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 4));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 8));
sum_sf = Q6_Vsf_vadd_VsfVsf(sum_sf, Q6_V_vror_VR(sum_sf, VLEN / 16));
return sum_sf;
}
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // fp32 nbytes
HVX_Vector sum = in, sum_t;
while (width < total) {
sum_t = Q6_V_vror_VR(sum, width); // rotate right
sum = Q6_Vsf_vadd_VsfVsf(sum, sum_t); // elementwise sum
width = width << 1;
}
return sum;
}
#else
static inline HVX_Vector hvx_vec_reduce_sum_f32x2(HVX_Vector in0, HVX_Vector in1) {
HVX_VectorPair sump = Q6_W_vshuff_VVR(in1, in0, 4);
HVX_Vector sum_qf = Q6_Vqf32_vadd_VsfVsf(Q6_V_lo_W(sump), Q6_V_hi_W(sump));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 2));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 4));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 8));
sum_qf = Q6_Vqf32_vadd_Vqf32Vsf(sum_qf, Q6_V_vror_VR(Q6_Vsf_equals_Vqf32(sum_qf), VLEN / 16));
return Q6_Vsf_equals_Vqf32(sum_qf);
}
static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) { static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n) {
unsigned int total = n * 4; // total vec nbytes unsigned int total = n * 4; // total vec nbytes
unsigned int width = 4; // fp32 nbytes unsigned int width = 4; // fp32 nbytes
@ -57,6 +96,8 @@ static inline HVX_Vector hvx_vec_reduce_sum_n_f32(HVX_Vector in, unsigned int n)
return sum; return sum;
} }
#endif
static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) { static inline HVX_Vector hvx_vec_reduce_sum_f32(HVX_Vector in) {
return hvx_vec_reduce_sum_n_f32(in, 32); return hvx_vec_reduce_sum_n_f32(in, 32);
} }
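The two reduction helpers fold a vector with log2(n) rotate-and-add steps; the x2 variant shuffles two vectors together first so one folding pass leaves both sums in adjacent lanes. A scalar model follows, assuming 32 fp32 lanes per 128-byte vector (the qf32 and native-sf branches above differ only in which add intrinsic they use):

#include <string.h>

#define LANES 32   /* fp32 lanes per 128-byte HVX vector */

static void rotate_right_ref(float v[LANES], int elems) {   /* model of Q6_V_vror_VR (bytes / 4) */
    float tmp[LANES];
    for (int i = 0; i < LANES; ++i) {
        tmp[i] = v[(i + elems) % LANES];
    }
    memcpy(v, tmp, sizeof(tmp));
}

static float reduce_sum_n_ref(const float in[LANES], int n) {
    float v[LANES];
    memcpy(v, in, sizeof(v));
    for (int width = 1; width < n; width <<= 1) {            /* log2(n) fold steps */
        float t[LANES];
        memcpy(t, v, sizeof(t));
        rotate_right_ref(t, width);
        for (int i = 0; i < LANES; ++i) {
            v[i] += t[i];
        }
    }
    return v[0];                                             /* lane 0 holds the n-lane sum */
}

static void reduce_sum_x2_ref(float out2[2], const float a[LANES], const float b[LANES]) {
    /* The HVX version interleaves a and b with one shuffle so a single folding pass
     * leaves sum(a) in lane 0 and sum(b) in lane 1; the scalar model computes both. */
    out2[0] = reduce_sum_n_ref(a, LANES);
    out2[1] = reduce_sum_n_ref(b, LANES);
}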

View File

@ -11,6 +11,7 @@
#include "hex-dma.h" #include "hex-dma.h"
#include "hvx-utils.h" #include "hvx-utils.h"
#include "hvx-dump.h"
#define GGML_COMMON_DECL_C #define GGML_COMMON_DECL_C
#include "ggml-common.h" #include "ggml-common.h"
@ -320,7 +321,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
// Row sum (qf32) // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32. // Multiply and accumulate into int32.
@ -344,7 +345,7 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@ -362,14 +363,14 @@ static void vec_dot_q4x4x2_q8x4x2(const int n, float * restrict s, const void *
// Zero out unused scales // Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Reduce and convert into fp32 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
hvx_vec_store_u(&s[0], 4, r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum);
} }
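The accumulation pattern after this change is: per-block integer dot product, multiplied by the combined fp32 scale, added straight into an fp32 ("sf") running sum instead of a qf32 sum converted at the end. A scalar sketch with a simplified generic block layout (the names and the flat quant/scale arrays are illustrative):

#include <stdint.h>
#include <stddef.h>

static float vec_dot_blocks_ref(size_t nblocks,
                                size_t block_size,
                                const int8_t *xq, const float *xd,   /* x quants + per-block scales */
                                const int8_t *yq, const float *yd) { /* y quants + per-block scales */
    float sum = 0.0f;                           /* fp32 ("sf") accumulator */
    for (size_t b = 0; b < nblocks; ++b) {
        int32_t idot = 0;                       /* integer MAC, as in the vrmpy stage */
        for (size_t i = 0; i < block_size; ++i) {
            idot += (int32_t) xq[b * block_size + i] * yq[b * block_size + i];
        }
        sum += (float) idot * (xd[b] * yd[b]);  /* combined scale applied per block */
    }
    return sum;
}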
@ -402,7 +403,7 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
// Row sum (qf32) // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0);
@ -432,8 +433,8 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@ -456,20 +457,18 @@ static void vec_dot_q4x4x2_q8x4x2_rx2(const int n,
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Convert into fp32 and reduce HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 8, rsum);
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
} }
static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -493,7 +492,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
// Row sum (qf32) // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32. // Multiply and accumulate into int32.
@ -517,7 +516,7 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@ -535,14 +534,14 @@ static void vec_dot_q8x4x2_q8x4x2(const int n, float * restrict s, const void *
// Zero out unused scales // Zero out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Reduce and convert into fp32 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
hvx_vec_store_u(&s[0], 4, r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum);
} }
@ -605,8 +604,8 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks // Process leftovers, we still load full 4x4x2 block but zero out unused scales/blocks
@ -629,20 +628,18 @@ static void vec_dot_q8x4x2_q8x4x2_rx2(const int n,
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Convert into fp32 and reduce HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 8, rsum);
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
} }
static void vec_dot_mxfp4x4x2_q8x4x2(const int n, static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
@ -669,7 +666,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
// Row sum (qf32) // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r0_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32. // Multiply and accumulate into int32.
@ -708,7 +705,7 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Process leftovers // Process leftovers
@ -741,14 +738,14 @@ static void vec_dot_mxfp4x4x2_q8x4x2(const int n,
// Zero-out unused scales // Zero-out unused scales
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
} }
// Reduce and convert into fp32 r0_sum = hvx_vec_reduce_sum_f32(r0_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum));
hvx_vec_store_u(&s[0], 4, r0_sum); hvx_vec_store_u(&s[0], 4, r0_sum);
} }
@ -781,13 +778,13 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first const uint8_t * restrict y_q = ((const uint8_t *) vy + 0); // quants first
const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales const uint8_t * restrict y_d = ((const uint8_t *) vy + y_qrow_size); // then scales
// Row sum (qf32) // Row sum (sf)
HVX_Vector r0_sum = Q6_V_vsplat_R(0); HVX_Vector r0_sum = Q6_V_vsplat_R(0);
HVX_Vector r1_sum = Q6_V_vsplat_R(0); HVX_Vector r1_sum = Q6_V_vsplat_R(0);
// Multiply and accumulate into int32. // Multiply and accumulate into int32.
// Compute combined scale (fp32). // Compute combined scale (fp32).
// Apply scale to acc and accumulate into the row sum (qf32). // Apply scale to acc and accumulate into the row sum (f32).
const uint32_t nb = n / qk; // num full blocks const uint32_t nb = n / qk; // num full blocks
int32_t nloe = n % qk; // num leftover elements (must be signed) int32_t nloe = n % qk; // num leftover elements (must be signed)
@ -829,8 +826,8 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Process leftovers // Process leftovers
@ -867,24 +864,22 @@ static void vec_dot_mxfp4x4x2_q8x4x2_rx2(const int n,
HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d)); HVX_Vector r0_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r0_d, vy_d));
HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d)); HVX_Vector r1_dd = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vmpy_VsfVsf(r1_d, vy_d));
// Zero-out unused scales // Zero-out unused values
HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8); HVX_VectorPred bmask = Q6_Q_vsetq_R(nloe / 8);
r0_dd = Q6_V_vand_QV(bmask, r0_dd); r0_dd = Q6_V_vand_QV(bmask, r0_dd);
r1_dd = Q6_V_vand_QV(bmask, r1_dd); r1_dd = Q6_V_vand_QV(bmask, r1_dd);
r0_ia = Q6_V_vand_QV(bmask, r0_ia);
r1_ia = Q6_V_vand_QV(bmask, r1_ia);
HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd); HVX_Vector r0_fa = Q6_Vqf32_vmpy_VsfVsf(r0_ia, r0_dd);
HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd); HVX_Vector r1_fa = Q6_Vqf32_vmpy_VsfVsf(r1_ia, r1_dd);
r0_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r0_sum, r0_fa); r0_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r0_fa, r0_sum));
r1_sum = Q6_Vqf32_vadd_Vqf32Vqf32(r1_sum, r1_fa); r1_sum = Q6_Vsf_equals_Vqf32(Q6_Vqf32_vadd_Vqf32Vsf(r1_fa, r1_sum));
} }
// Convert into fp32 and reduce HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(r0_sum, r1_sum);
r0_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r0_sum)); hvx_vec_store_u(&s[0], 8, rsum);
r1_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(r1_sum));
HVX_VectorPair p0 = Q6_W_vshuff_VVR(r1_sum, r0_sum, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
} }
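Leftover handling in these kernels still loads a full padded block but now zeroes both the per-block scales and the per-block integer sums for blocks past the valid range before multiplying and accumulating. A scalar sketch of that masking (illustrative names; nvalid_blocks stands in for the nloe-derived predicate):

#include <stdint.h>
#include <stddef.h>

static float accumulate_tail_ref(size_t nvalid_blocks,       /* blocks actually covered by nloe */
                                 size_t total_blocks,        /* blocks loaded from the padded buffer */
                                 const int32_t *idot,        /* per-block integer dot products */
                                 const float *scale) {       /* per-block combined scales */
    float sum = 0.0f;
    for (size_t b = 0; b < total_blocks; ++b) {
        int32_t d = (b < nvalid_blocks) ? idot[b]  : 0;      /* mask integer sums (r*_ia) */
        float   s = (b < nvalid_blocks) ? scale[b] : 0.0f;   /* mask scales (r*_dd) */
        sum += (float) d * s;
    }
    return sum;
}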
static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -913,7 +908,7 @@ static void vec_dot_f16_f16_aa(const int n, float * restrict s, const void * res
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
} }
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum); hvx_vec_store_u(&s[0], 4, rsum);
} }
@ -957,11 +952,8 @@ static void vec_dot_f16_f16_aa_rx2(const int n,
rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf))); rsum1 = Q6_Vqf32_vadd_Vqf32Vqf32(rsum1, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy1_qf), Q6_V_hi_W(xy1_qf)));
} }
rsum0 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum0)); HVX_Vector rsum = hvx_vec_reduce_sum_f32x2(Q6_Vsf_equals_Vqf32(rsum0), Q6_Vsf_equals_Vqf32(rsum1));
rsum1 = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum1)); hvx_vec_store_u(&s[0], 8, rsum);
HVX_VectorPair p0 = Q6_W_vshuff_VVR(rsum1, rsum0, 4);
hvx_vec_store_u(&s[0], 8, Q6_V_lo_W(p0));
} }
static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) { static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * restrict vx, const void * restrict vy) {
@ -990,7 +982,7 @@ static void vec_dot_f16_f16_uu(const int n, float * restrict s, const void * res
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
} }
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum); hvx_vec_store_u(&s[0], 4, rsum);
} }
@ -1042,7 +1034,8 @@ static void vec_dot_f16_f32_uu(const int n, float * restrict s, const void * res
rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf))); rsum = Q6_Vqf32_vadd_Vqf32Vqf32(rsum, Q6_Vqf32_vadd_Vqf32Vqf32(Q6_V_lo_W(xy_qf), Q6_V_hi_W(xy_qf)));
} }
rsum = Q6_Vsf_equals_Vqf32(hvx_vec_reduce_sum_qf32(rsum)); // Convert into fp32 and reduce
rsum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(rsum));
hvx_vec_store_u(&s[0], 4, rsum); hvx_vec_store_u(&s[0], 4, rsum);
} }

View File

@ -154,8 +154,8 @@ static void hvx_fast_softmax_f32(const uint8_t * restrict src,
v_pad[i] = v3; v_pad[i] = v3;
} }
v = hvx_vec_reduce_sum_qf32(sum_vec); v = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_vec));
sum_vec = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(v)); sum_vec = hvx_vec_repl4(v);
HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v); HVX_VectorPred pos_sum = Q6_Q_vcmp_gt_VwVw(sum_vec, zero_v);
HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec); HVX_Vector v4 = hvx_vec_inverse_f32(sum_vec);
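The softmax hunk reduces the partial sums to one value, broadcasts it, checks that it is positive and multiplies by its reciprocal. A scalar sketch of that normalisation step, assuming the exponentials have already been written to dst (only the reduction and reciprocal appear in the hunk itself):

#include <stddef.h>

static void softmax_normalize_ref(float *dst, size_t n) {
    float sum = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sum += dst[i];                    /* hvx_vec_reduce_sum_f32 equivalent */
    }
    if (sum > 0.0f) {                     /* mirrors the pos_sum predicate check */
        float inv = 1.0f / sum;           /* hvx_vec_inverse_f32 equivalent */
        for (size_t i = 0; i < n; ++i) {
            dst[i] *= inv;
        }
    }
}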

View File

@ -57,8 +57,8 @@ static void hvx_fast_rms_norm_f32(const uint8_t * restrict src,
sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2); sum_v = Q6_Vqf32_vadd_Vqf32Vqf32(sum_v, v2);
} }
HVX_Vector reduced_sum = hvx_vec_reduce_sum_qf32(sum_v); HVX_Vector reduced_sum = hvx_vec_reduce_sum_f32(Q6_Vsf_equals_Vqf32(sum_v));
sum_v = hvx_vec_repl4(Q6_Vsf_equals_Vqf32(reduced_sum)); sum_v = hvx_vec_repl4(reduced_sum);
HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems); HVX_Vector t_v = hvx_vec_splat_f32((float) num_elems);
HVX_Vector denom_v = hvx_vec_inverse_f32(t_v); HVX_Vector denom_v = hvx_vec_inverse_f32(t_v);
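The RMS-norm hunk likewise reduces a sum of squares and prepares the 1/num_elems factor. A scalar sketch of the full normalisation follows; the epsilon and the final scaling are assumed from the usual RMS-norm definition rather than shown in the hunk.

#include <stddef.h>
#include <math.h>

static void rms_norm_ref(float *dst, const float *src, size_t n, float eps) {
    float sumsq = 0.0f;
    for (size_t i = 0; i < n; ++i) {
        sumsq += src[i] * src[i];                       /* reduced sum of squares */
    }
    float scale = 1.0f / sqrtf(sumsq / (float) n + eps); /* 1 / rms */
    for (size_t i = 0; i < n; ++i) {
        dst[i] = src[i] * scale;
    }
}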

Some files were not shown because too many files have changed in this diff