Merge remote-tracking branch 'upstream/master'

This commit is contained in:
Reese Levine 2025-09-11 17:13:02 -07:00
commit a5da437098
147 changed files with 7804 additions and 3694 deletions

View File

@ -1063,7 +1063,17 @@ jobs:
run: | run: |
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
- name: Install - name: Cache ROCm Installation
id: cache-rocm
uses: actions/cache@v4
with:
path: C:\Program Files\AMD\ROCm
key: rocm-6.1-${{ runner.os }}-v1
restore-keys: |
rocm-6.1-${{ runner.os }}-
- name: Install ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
id: depends id: depends
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@ -1071,13 +1081,28 @@ jobs:
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK" write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$proc.WaitForExit(600000) $completed = $proc.WaitForExit(600000)
if (-not $completed) {
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed AMD HIP SDK installation" write-host "Completed AMD HIP SDK installation"
- name: Verify ROCm - name: Verify ROCm
id: verify id: verify
run: | run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version # Find and test ROCm installation
$clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1
if (-not $clangPath) {
Write-Error "ROCm installation not found"
exit 1
}
& $clangPath.FullName --version
- name: Install ccache - name: Install ccache
uses: ggml-org/ccache-action@v1.2.16 uses: ggml-org/ccache-action@v1.2.16

View File

@ -17,7 +17,7 @@ jobs:
steps: steps:
- uses: actions/stale@v5 - uses: actions/stale@v5
with: with:
exempt-issue-labels: "refactoring,help wanted,good first issue,research,bug,roadmap" exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
days-before-issue-stale: 30 days-before-issue-stale: 30
days-before-issue-close: 14 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"

View File

@ -544,13 +544,23 @@ jobs:
run: | run: |
git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1 git clone https://github.com/rocm/rocwmma --branch rocm-6.2.4 --depth 1
- name: Cache ROCm Installation
id: cache-rocm
uses: actions/cache@v4
with:
path: C:\Program Files\AMD\ROCm
key: rocm-6.1-${{ runner.os }}-v1
restore-keys: |
rocm-6.1-${{ runner.os }}-
- name: ccache - name: ccache
uses: ggml-org/ccache-action@v1.2.16 uses: ggml-org/ccache-action@v1.2.16
with: with:
key: windows-latest-cmake-hip-${{ matrix.name }}-x64 key: windows-latest-cmake-hip-${{ matrix.name }}-x64
evict-old-files: 1d evict-old-files: 1d
- name: Install - name: Install ROCm
if: steps.cache-rocm.outputs.cache-hit != 'true'
id: depends id: depends
run: | run: |
$ErrorActionPreference = "Stop" $ErrorActionPreference = "Stop"
@ -558,13 +568,28 @@ jobs:
Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe" Invoke-WebRequest -Uri "https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q3-WinSvr2022-For-HIP.exe" -OutFile "${env:RUNNER_TEMP}\rocm-install.exe"
write-host "Installing AMD HIP SDK" write-host "Installing AMD HIP SDK"
$proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru $proc = Start-Process "${env:RUNNER_TEMP}\rocm-install.exe" -ArgumentList '-install' -NoNewWindow -PassThru
$proc.WaitForExit(600000) $completed = $proc.WaitForExit(600000)
if (-not $completed) {
Write-Error "ROCm installation timed out after 10 minutes. Killing the process"
$proc.Kill()
exit 1
}
if ($proc.ExitCode -ne 0) {
Write-Error "ROCm installation failed with exit code $($proc.ExitCode)"
exit 1
}
write-host "Completed AMD HIP SDK installation" write-host "Completed AMD HIP SDK installation"
- name: Verify ROCm - name: Verify ROCm
id: verify id: verify
run: | run: |
& 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' --version # Find and test ROCm installation
$clangPath = Get-ChildItem 'C:\Program Files\AMD\ROCm\*\bin\clang.exe' | Select-Object -First 1
if (-not $clangPath) {
Write-Error "ROCm installation not found"
exit 1
}
& $clangPath.FullName --version
- name: Build - name: Build
id: cmake_build id: cmake_build

View File

@ -16,6 +16,9 @@
- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)` - Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)`
- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules - Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules
- Consider adding yourself to [CODEOWNERS](CODEOWNERS) - Consider adding yourself to [CODEOWNERS](CODEOWNERS)
- Let authors, who are also collaborators, merge their own PRs
- When merging a PR by a contributor, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
# Coding guidelines # Coding guidelines

View File

@ -1263,6 +1263,18 @@ static std::string list_builtin_chat_templates() {
return msg.str(); return msg.str();
} }
static bool is_truthy(const std::string & value) {
return value == "on" || value == "enabled" || value == "1";
}
static bool is_falsey(const std::string & value) {
return value == "off" || value == "disabled" || value == "0";
}
static bool is_autoy(const std::string & value) {
return value == "auto" || value == "-1";
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// load dynamic backends // load dynamic backends
ggml_backend_load_all(); ggml_backend_load_all();
@ -1544,21 +1556,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_chunks = value; params.n_chunks = value;
} }
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL})); ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg( add_opt(common_arg({ "-fa", "--flash-attn" }, "[on|off|auto]",
{"-fa", "--flash-attn"}, "FA", string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')",
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)), llama_flash_attn_type_name(params.flash_attn_type)),
[](common_params & params, const std::string & value) { [](common_params & params, const std::string & value) {
if (value == "on" || value == "enabled" || value == "1") { if (is_truthy(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED; params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
} else if (value == "off" || value == "disabled" || value == "0") { } else if (is_falsey(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED; params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (value == "auto" || value == "-1") { } else if (is_autoy(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
} else { } else {
throw std::runtime_error(string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str())); throw std::runtime_error(
} string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
} }
).set_env("LLAMA_ARG_FLASH_ATTN")); }).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg( add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT", {"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys", "prompt to start generation with; for system message, use -sys",
@ -3134,13 +3146,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str()); common_log_set_file(common_log_main(), value.c_str());
} }
)); ));
add_opt(common_arg( add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
{"--log-colors"}, "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"Enable colored logging", "'auto' enables colors when output is to a terminal",
[](common_params &) { [](common_params &, const std::string & value) {
common_log_set_colors(common_log_main(), true); if (is_truthy(value)) {
} common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
).set_env("LLAMA_LOG_COLORS")); } else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg( add_opt(common_arg(
{"-v", "--verbose", "--log-verbose"}, {"-v", "--verbose", "--log-verbose"},
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)", "Set verbosity level to infinity (i.e. log all messages, useful for debugging)",

View File

@ -163,6 +163,19 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
throw std::runtime_error("Invalid tool_choice: " + tool_choice); throw std::runtime_error("Invalid tool_choice: " + tool_choice);
} }
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
common_chat_templates_inputs dummy_inputs;
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
dummy_inputs.messages = {msg};
dummy_inputs.enable_thinking = false;
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
dummy_inputs.enable_thinking = true;
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
}
template <> template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) { std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs; std::vector<common_chat_msg> msgs;
@ -618,11 +631,13 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2"; case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1"; case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro"; case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B"; case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default: default:
throw std::runtime_error("Unknown chat format"); throw std::runtime_error("Unknown chat format");
} }
@ -684,11 +699,13 @@ static void parse_json_tool_calls(
size_t from = std::string::npos; size_t from = std::string::npos;
auto first = true; auto first = true;
while (true) { while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first auto res = function_regex_start_only && first
? builder.try_consume_regex(*function_regex_start_only) ? builder.try_consume_regex(*function_regex_start_only)
: function_regex : function_regex
? builder.try_find_regex(*function_regex, from) ? builder.try_find_regex(*function_regex, from)
: std::nullopt; : std::nullopt;
if (res) { if (res) {
std::string name; std::string name;
if (get_function_name) { if (get_function_name) {
@ -723,6 +740,8 @@ static void parse_json_tool_calls(
return; return;
} }
throw common_chat_msg_partial_exception("incomplete tool call"); throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
} }
break; break;
} }
@ -1184,6 +1203,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
}); });
return data; return data;
} }
static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
// When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = true;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{ "type", "object" },
{ "properties",
{
{ "name",
{
{ "type", "string" },
{ "const", function.at("name") },
} },
{ "arguments", function.at("parameters") },
} },
{ "required", json::array({ "name", "arguments" }) },
});
});
auto schema = json{
{ "type", "array" },
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
{ "minItems", 1 },
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
" \"</TOOLCALL>\"");
});
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ?
"[\\s\\S]*?(</think>\\s*)" :
"(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<TOOLCALL>)[\\s\\S]*" });
}
return data;
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) { if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest()); builder.add_content(builder.consume_rest());
@ -1313,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
} }
return data; return data;
} }
static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for DeepSeek V3.1 template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
auto prompt = apply(tmpl, inputs,
/* messages_override= */ inputs.messages,
/* tools_override= */ std::nullopt,
additional_context);
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
if (string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"( \"<tool▁call▁begin>\" )? \"" + name + "<tool▁sep>"
"\" " + builder.add_schema(name + "-args", parameters) + " "
"\"<tool▁call▁end>\""));
});
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"( \"<tool▁calls▁begin>\" | \"<tool_calls_begin>\" | \"<tool calls begin>\" | \"<tool\\\\_calls\\\\_begin>\" | \"<tool▁calls>\" ) "
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
"\"<tool▁calls▁end>\""
" space");
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)[\\s\\S]*"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁calls▁begin>",
"<tool▁call▁begin>",
"<tool▁sep>",
"<tool▁call▁end>",
"<tool▁calls▁end>",
};
});
}
return data;
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) { static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>"); builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) { if (!builder.syntax().parse_tool_calls) {
@ -1334,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end); tool_calls_end);
} }
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<toolsep>)");
static const common_regex close_regex("(?:[\\s]*)?<toolcallend>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) { static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data; common_chat_params data;
auto prompt = apply(tmpl, inputs); auto prompt = apply(tmpl, inputs);
@ -1830,7 +2035,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
// If thinking_forced_open, then we capture the </think> tag in the grammar, // If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + ( std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
"(\\s*" "\\s*("
"(?:<tool_call>" "(?:<tool_call>"
"|<function" "|<function"
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?" "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
@ -2060,6 +2265,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
} }
} }
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content // Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>"); builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@ -2263,6 +2495,12 @@ static common_chat_params common_chat_templates_apply_jinja(
} }
} }
// DeepSeek V3.1: detect based on specific patterns in the template
if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos &&
params.json_schema.is_null()) {
return common_chat_params_init_deepseek_v3_1(tmpl, params);
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools). // DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) { if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, params); return common_chat_params_init_deepseek_r1(tmpl, params);
@ -2293,6 +2531,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs); return common_chat_params_init_seed_oss(tmpl, params, inputs);
} }
// Nemotron v2
if (src.find("<SPECIAL_10>") != std::string::npos) {
return common_chat_params_init_nemotron_v2(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema. // Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below. // TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) { if ((params.tools.is_array() && params.json_schema.is_object())) {
@ -2430,6 +2673,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_DEEPSEEK_R1: case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder); common_chat_parse_deepseek_r1(builder);
break; break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder); common_chat_parse_functionary_v3_2(builder);
break; break;
@ -2454,6 +2700,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS: case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder); common_chat_parse_seed_oss(builder);
break; break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
default: default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
} }

View File

@ -107,11 +107,13 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FIREFUNCTION_V2, COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1, COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO, COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B, COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
}; };
@ -198,6 +200,8 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_p
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
// Parses a JSON array of messages in OpenAI's chat completion API format. // Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json // T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages); template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);

View File

@ -843,9 +843,10 @@ public:
_build_object_rule( _build_object_rule(
properties, required, name, properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json())); schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) { } else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
std::unordered_set<std::string> required; std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties; std::vector<std::pair<std::string, json>> properties;
std::map<std::string, size_t> enum_values;
std::string hybrid_name = name; std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) { std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) { if (comp_schema.contains("$ref")) {
@ -857,6 +858,14 @@ public:
required.insert(prop.key()); required.insert(prop.key());
} }
} }
} else if (comp_schema.contains("enum")) {
for (const auto & v : comp_schema["enum"]) {
const auto rule = _generate_constant_rule(v);
if (enum_values.find(rule) == enum_values.end()) {
enum_values[rule] = 0;
}
enum_values[rule] += 1;
}
} else { } else {
// todo warning // todo warning
} }
@ -870,6 +879,17 @@ public:
add_component(t, true); add_component(t, true);
} }
} }
if (!enum_values.empty()) {
std::vector<std::string> enum_intersection;
for (const auto & p : enum_values) {
if (p.second == schema["allOf"].size()) {
enum_intersection.push_back(p.first);
}
}
if (!enum_intersection.empty()) {
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json())); return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) { } else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"]; json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];

View File

@ -4,17 +4,52 @@
#include <condition_variable> #include <condition_variable>
#include <cstdarg> #include <cstdarg>
#include <cstdio> #include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex> #include <mutex>
#include <sstream> #include <sstream>
#include <thread> #include <thread>
#include <vector> #include <vector>
#if defined(_WIN32)
# include <io.h>
# include <windows.h>
# define isatty _isatty
# define fileno _fileno
#else
# include <unistd.h>
#endif // defined(_WIN32)
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
void common_log_set_verbosity_thold(int verbosity) { void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity; common_log_verbosity_thold = verbosity;
} }
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() { static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
} }
@ -353,6 +388,11 @@ struct common_log * common_log_init() {
struct common_log * common_log_main() { struct common_log * common_log_main() {
static struct common_log log; static struct common_log log;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(common_log_should_use_colors_auto());
});
return &log; return &log;
} }
@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) {
log->set_file(file); log->set_file(file);
} }
void common_log_set_colors(struct common_log * log, bool colors) { void common_log_set_colors(struct common_log * log, log_colors colors) {
log->set_colors(colors); if (colors == LOG_COLORS_AUTO) {
log->set_colors(common_log_should_use_colors_auto());
return;
}
if (colors == LOG_COLORS_DISABLED) {
log->set_colors(false);
return;
}
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
log->set_colors(true);
} }
void common_log_set_prefix(struct common_log * log, bool prefix) { void common_log_set_prefix(struct common_log * log, bool prefix) {

View File

@ -24,6 +24,12 @@
#define LOG_DEFAULT_DEBUG 1 #define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0 #define LOG_DEFAULT_LLAMA 0
enum log_colors {
LOG_COLORS_AUTO = -1,
LOG_COLORS_DISABLED = 0,
LOG_COLORS_ENABLED = 1,
};
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower // needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity() // set via common_log_set_verbosity()
extern int common_log_verbosity_thold; extern int common_log_verbosity_thold;
@ -65,10 +71,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
// D - debug (stderr, V = LOG_DEFAULT_DEBUG) // D - debug (stderr, V = LOG_DEFAULT_DEBUG)
// //
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
// helper macros for logging // helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold

View File

@ -5128,6 +5128,20 @@ class EmbeddingGemma(Gemma3Model):
def set_gguf_parameters(self): def set_gguf_parameters(self):
super().set_gguf_parameters() super().set_gguf_parameters()
# Override the sliding window size as it gets adjusted by the Gemma3TextConfig
# constructor. We want to use the value from the original model's config.json.
# ref: https://github.com/huggingface/transformers/pull/40700
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
config = json.load(f)
orig_sliding_window = config.get("sliding_window")
if orig_sliding_window is None:
raise ValueError("sliding_window not found in model config - this is required for the model")
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
f"instead of {self.hparams['sliding_window']}")
self.gguf_writer.add_sliding_window(orig_sliding_window)
self._try_set_pooling_type() self._try_set_pooling_type()
@ -6687,6 +6701,8 @@ class T5Model(TextModel):
self.gguf_writer.add_embedding_length(self.hparams["d_model"]) self.gguf_writer.add_embedding_length(self.hparams["d_model"])
self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"]) self.gguf_writer.add_feed_forward_length(self.hparams["d_ff"])
self.gguf_writer.add_block_count(self.hparams["num_layers"]) self.gguf_writer.add_block_count(self.hparams["num_layers"])
if (dec_n_layer := self.hparams.get("num_decoder_layers")) is not None:
self.gguf_writer.add_decoder_block_count(dec_n_layer)
self.gguf_writer.add_head_count(self.hparams["num_heads"]) self.gguf_writer.add_head_count(self.hparams["num_heads"])
self.gguf_writer.add_key_length(self.hparams["d_kv"]) self.gguf_writer.add_key_length(self.hparams["d_kv"])
self.gguf_writer.add_value_length(self.hparams["d_kv"]) self.gguf_writer.add_value_length(self.hparams["d_kv"])

View File

@ -12,7 +12,7 @@ import json
from math import prod from math import prod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig from transformers import AutoConfig, AutoTokenizer
import torch import torch
@ -26,6 +26,8 @@ import gguf
# reuse model definitions from convert_hf_to_gguf.py # reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase from convert_hf_to_gguf import LazyTorchTensor, ModelBase
from gguf.constants import GGUFValueType
logger = logging.getLogger("lora-to-gguf") logger = logging.getLogger("lora-to-gguf")
@ -369,7 +371,31 @@ if __name__ == '__main__':
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self): def set_gguf_parameters(self):
logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
alora_invocation_tokens = lparams.get("alora_invocation_tokens")
invocation_string = lparams.get("invocation_string")
if invocation_string and not alora_invocation_tokens:
logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
base_model_path_or_id = hparams.get("_name_or_path")
try:
tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
except ValueError:
logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
raise
# NOTE: There's an off-by-one with the older aLoRAs where
# the invocation string includes the "<|start_of_turn|>"
# token, but the adapters themselves were trained to
# activate _after_ that first token, so we drop it here.
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
if alora_invocation_tokens:
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
self.gguf_writer.add_key_value(
gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
alora_invocation_tokens,
GGUFValueType.ARRAY,
GGUFValueType.UINT32,
)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters # Never add extra tensors (e.g. rope_freqs) for LoRA adapters

View File

@ -314,3 +314,11 @@ Converting the matmul weight format from ND to NZ to improve performance. Enable
### GGML_CANN_ACL_GRAPH ### GGML_CANN_ACL_GRAPH
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default. Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.
### GGML_CANN_GRAPH_CACHE_CAPACITY
Maximum number of compiled CANN graphs kept in the LRU cache, default is 12. When the number of cached graphs exceeds this capacity, the least recently used graph will be evicted.
### GGML_CANN_PREFILL_USE_GRAPH
Enable ACL graph execution during the prefill stage, default is false. This option is only effective when FA is enabled.

View File

@ -42,18 +42,6 @@ cmake --build build --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc) cmake --build build --config Release -j $(nproc)
``` ```
- By default, NNPA is disabled by default. To enable it:
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
-DGGML_NNPA=ON
cmake --build build --config Release -j $(nproc)
```
- For debug builds: - For debug builds:
```bash ```bash
@ -164,15 +152,11 @@ All models need to be converted to Big-Endian. You can achieve this in three cas
Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation. Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation.
### 2. NNPA Vector Intrinsics Acceleration ### 2. zDNN Accelerator (WIP)
Only available in IBM z16/LinuxONE 4 or later system with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
### 3. zDNN Accelerator (WIP)
Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines. Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines.
### 4. Spyre Accelerator ### 3. Spyre Accelerator
_Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._ _Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._
@ -230,10 +214,6 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
CXXFLAGS="-include cstdint" pip3 install -r requirements.txt CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
``` ```
5. `-DGGML_NNPA=ON` generates gibberish output
Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either try reducing the number of threads, or disable the compile option using `-DGGML_NNPA=OFF`.
## Getting Help on IBM Z & LinuxONE ## Getting Help on IBM Z & LinuxONE
1. **Bugs, Feature Requests** 1. **Bugs, Feature Requests**
@ -258,38 +238,38 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
## Appendix B: SIMD Support Matrix ## Appendix B: SIMD Support Matrix
| | VX/VXE/VXE2 | NNPA | zDNN | Spyre | | | VX/VXE/VXE2 | zDNN | Spyre |
| ---------- | ----------- | ---- | ---- | ----- | |------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ✅ | ❓ | | FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ | ❓ | | FP16 | ✅ | ❓ | ❓ |
| BF16 | 🚫 | 🚫 | ❓ | ❓ | | BF16 | 🚫 | ❓ | ❓ |
| Q4_0 | ✅ | ✅ | ❓ | ❓ | | Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ✅ | ❓ | ❓ | | Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | 🚫 | ❓ | ❓ | | MXFP4 | 🚫 | ❓ | ❓ |
| Q5_0 | ✅ | ✅ | ❓ | ❓ | | Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ✅ | ❓ | ❓ | | Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ✅ | ❓ | ❓ | | Q8_0 | ✅ | ❓ | ❓ |
| Q2_K | 🚫 | 🚫 | ❓ | ❓ | | Q2_K | 🚫 | ❓ | ❓ |
| Q3_K | ✅ | ✅ | ❓ | ❓ | | Q3_K | ✅ | ❓ | ❓ |
| Q4_K | ✅ | ✅ | ❓ | ❓ | | Q4_K | ✅ | ❓ | ❓ |
| Q5_K | ✅ | ✅ | ❓ | ❓ | | Q5_K | ✅ | ❓ | ❓ |
| Q6_K | ✅ | ✅ | ❓ | ❓ | | Q6_K | ✅ | ❓ | ❓ |
| TQ1_0 | 🚫 | 🚫 | ❓ | ❓ | | TQ1_0 | 🚫 | ❓ | ❓ |
| TQ2_0 | 🚫 | 🚫 | ❓ | ❓ | | TQ2_0 | 🚫 | ❓ | ❓ |
| IQ2_XXS | 🚫 | 🚫 | ❓ | ❓ | | IQ2_XXS | 🚫 | ❓ | ❓ |
| IQ2_XS | 🚫 | 🚫 | ❓ | ❓ | | IQ2_XS | 🚫 | ❓ | ❓ |
| IQ2_S | 🚫 | 🚫 | ❓ | ❓ | | IQ2_S | 🚫 | ❓ | ❓ |
| IQ3_XXS | 🚫 | 🚫 | ❓ | ❓ | | IQ3_XXS | 🚫 | ❓ | ❓ |
| IQ3_S | 🚫 | 🚫 | ❓ | ❓ | | IQ3_S | 🚫 | ❓ | ❓ |
| IQ1_S | 🚫 | 🚫 | ❓ | ❓ | | IQ1_S | 🚫 | ❓ | ❓ |
| IQ1_M | 🚫 | 🚫 | ❓ | ❓ | | IQ1_M | 🚫 | ❓ | ❓ |
| IQ4_NL | ✅ | ✅ | ❓ | ❓ | | IQ4_NL | ✅ | ❓ | ❓ |
| IQ4_XS | ✅ | ✅ | ❓ | ❓ | | IQ4_XS | ✅ | ❓ | ❓ |
| FP32->FP16 | 🚫 | ✅ | ❓ | ❓ | | FP32->FP16 | 🚫 | ❓ | ❓ |
| FP16->FP32 | 🚫 | ✅ | ❓ | ❓ | | FP16->FP32 | 🚫 | ❓ | ❓ |
- ✅ - acceleration available - ✅ - acceleration available
- 🚫 - acceleration unavailable, will still run using scalar implementation - 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself - ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025. Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025.

View File

@ -333,17 +333,17 @@ static void print_params(struct my_llama_hparams * params) {
} }
static void print_tensor_info(const struct ggml_context * ctx) { static void print_tensor_info(const struct ggml_context * ctx) {
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
LOG_INF("%s: Allocating ", __func__); LOG_INF("%s: Allocating ", __func__);
int64_t total = 1; int64_t total = 1;
int i = 0; int i = 0;
for (; i < ggml_n_dims(t); ++i) { for (; i < ggml_n_dims(t); ++i) {
if (i > 0) LOG("x "); if (i > 0) { LOG_INF("x "); }
LOG("[%" PRId64 "] ", t->ne[i]); LOG_INF("[%" PRId64 "] ", t->ne[i]);
total *= t->ne[i]; total *= t->ne[i];
} }
if (i > 1) LOG("= [%" PRId64 "] ", total); if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
LOG("float space for %s\n", ggml_get_name(t)); LOG_INF("float space for %s\n", ggml_get_name(t));
} }
} }

View File

@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str; return str;
} }
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) { static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0]; size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v; float v;
@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t *
v = (float) *(int16_t *) &data[i]; v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) { } else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i]; v = (float) *(int8_t *) &data[i];
} else if (type == GGML_TYPE_BF16) {
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
} else { } else {
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
} }

View File

@ -586,9 +586,10 @@ class SchemaConverter:
properties = list(schema.get('properties', {}).items()) properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties'))) return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object') and 'allOf' in schema: elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
required = set() required = set()
properties = [] properties = []
enum_sets = []
hybrid_name = name hybrid_name = name
def add_component(comp_schema, is_required): def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None: if (ref := comp_schema.get('$ref')) is not None:
@ -600,6 +601,9 @@ class SchemaConverter:
if is_required: if is_required:
required.add(prop_name) required.add(prop_name)
if 'enum' in comp_schema:
enum_sets.append(set(comp_schema['enum']))
for t in schema['allOf']: for t in schema['allOf']:
if 'anyOf' in t: if 'anyOf' in t:
for tt in t['anyOf']: for tt in t['anyOf']:
@ -607,6 +611,15 @@ class SchemaConverter:
else: else:
add_component(t, is_required=True) add_component(t, is_required=True)
if enum_sets:
enum_intersection = enum_sets[0]
for s in enum_sets[1:]:
enum_intersection &= s
if enum_intersection:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None)) return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema): elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):

View File

@ -1,5 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu --extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.6.0 torch
torchvision~=0.21.0 torchvision
transformers~=4.55.0 transformers
huggingface-hub~=0.34.0 huggingface-hub
accelerate

View File

@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch import torch
import numpy as np import numpy as np
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') ### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
parser = argparse.ArgumentParser(description='Process model with specified path') # orig_rope = apply_rotary_pos_emb
parser.add_argument('--model-path', '-m', help='Path to the model') # torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
# # log inputs
# summarize(q, "RoPE.q_in")
# summarize(k, "RoPE.k_in")
# # call original
# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
# # log outputs
# summarize(q_out, "RoPE.q_out")
# summarize(k_out, "RoPE.k_out")
# return q_out, k_out
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### == END ROPE DEBUG ===
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
"""
Print a tensor in llama.cpp debug style.
Supports:
- 2D tensors (seq, hidden)
- 3D tensors (batch, seq, hidden)
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
Shows first and last max_vals of each vector per sequence position.
"""
t = tensor.detach().to(torch.float32).cpu()
# Determine dimensions
if t.ndim == 3:
_, s, _ = t.shape
elif t.ndim == 2:
_, s = 1, t.shape[0]
t = t.unsqueeze(0)
elif t.ndim == 4:
_, s, _, _ = t.shape
else:
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
return
ten_shape = t.shape
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
print(" [")
print(" [")
# Determine indices for first and last sequences
first_indices = list(range(min(s, max_seq)))
last_indices = list(range(max(0, s - max_seq), s))
# Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
# Combine indices
if has_overlap:
# If there's overlap, just use the combined unique indices
indices = sorted(list(set(first_indices + last_indices)))
separator_index = None
else:
# If no overlap, we'll add a separator between first and last sequences
indices = first_indices + last_indices
separator_index = len(first_indices)
for i, si in enumerate(indices):
# Add separator if needed
if separator_index is not None and i == separator_index:
print(" ...")
# Extract appropriate slice
vec = t[0, si]
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
flat = vec.flatten().tolist()
else: # 2D or 3D case
flat = vec.tolist()
# First and last slices
first = flat[:max_vals]
last = flat[-max_vals:] if len(flat) >= max_vals else flat
first_str = ", ".join(f"{v:12.4f}" for v in first)
last_str = ", ".join(f"{v:12.4f}" for v in last)
print(f" [{first_str}, ..., {last_str}]")
print(" ],")
print(" ]")
print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
args = parser.parse_args() args = parser.parse_args()
model_path = os.environ.get('MODEL_PATH', args.model_path) model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None: if model_path is None:
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable") parser.error(
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
)
config = AutoConfig.from_pretrained(model_path) config = AutoConfig.from_pretrained(model_path)
@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path)
if unreleased_model_name: if unreleased_model_name:
model_name_lower = unreleased_model_name.lower() model_name_lower = unreleased_model_name.lower()
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}" unreleased_module_path = (
f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
)
class_name = f"{unreleased_model_name}ForCausalLM" class_name = f"{unreleased_model_name}ForCausalLM"
print(f"Importing unreleased model module: {unreleased_module_path}") print(f"Importing unreleased model module: {unreleased_module_path}")
try: try:
model_class = getattr(importlib.import_module(unreleased_module_path), class_name) model_class = getattr(
model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained importlib.import_module(unreleased_module_path), class_name
)
model = model_class.from_pretrained(
model_path
) # Note: from_pretrained, not fromPretrained
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}") print(f"Failed to import or load model: {e}")
exit(1) exit(1)
else: else:
model = AutoModelForCausalLM.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload"
)
for name, module in model.named_modules():
if len(list(module.children())) == 0: # only leaf modules
module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path) model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful # Printing the Model class to allow for easier debugging. This can be useful

View File

@ -7,7 +7,7 @@ base_model:
Recommended way to run this model: Recommended way to run this model:
```sh ```sh
llama-server -hf {namespace}/{model_name}-GGUF llama-server -hf {namespace}/{model_name}-GGUF --embeddings
``` ```
Then the endpoint can be accessed at http://localhost:8080/embedding, for Then the endpoint can be accessed at http://localhost:8080/embedding, for

View File

@ -134,7 +134,6 @@ option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON) option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON) option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -132,6 +132,8 @@ extern "C" {
GGML_BACKEND_DEVICE_TYPE_CPU, GGML_BACKEND_DEVICE_TYPE_CPU,
// GPU device using dedicated memory // GPU device using dedicated memory
GGML_BACKEND_DEVICE_TYPE_GPU, GGML_BACKEND_DEVICE_TYPE_GPU,
// integrated GPU device using host memory
GGML_BACKEND_DEVICE_TYPE_IGPU,
// accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX) // accelerator devices intended to be used together with the CPU backend (e.g. BLAS or AMX)
GGML_BACKEND_DEVICE_TYPE_ACCEL GGML_BACKEND_DEVICE_TYPE_ACCEL
}; };
@ -150,11 +152,21 @@ extern "C" {
// all the device properties // all the device properties
struct ggml_backend_dev_props { struct ggml_backend_dev_props {
// device name
const char * name; const char * name;
// device description
const char * description; const char * description;
// device free memory in bytes
size_t memory_free; size_t memory_free;
// device total memory in bytes
size_t memory_total; size_t memory_total;
// device type
enum ggml_backend_dev_type type; enum ggml_backend_dev_type type;
// device id
// for PCI devices, this should be the PCI bus id formatted as "domain:bus:device.function" (e.g. "0000:01:00.0")
// if the id is unknown, this should be NULL
const char * device_id;
// device capabilities
struct ggml_backend_dev_caps caps; struct ggml_backend_dev_caps caps;
}; };

View File

@ -101,7 +101,6 @@ extern "C" {
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_vxe (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void);
GGML_BACKEND_API int ggml_cpu_has_nnpa (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
@ -135,6 +134,7 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t); GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t); GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);

View File

@ -43,14 +43,8 @@ GGML_BACKEND_API ggml_backend_t ggml_backend_metal_init(void);
GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend); GGML_BACKEND_API bool ggml_backend_is_metal(ggml_backend_t backend);
GGML_DEPRECATED(
GGML_BACKEND_API ggml_backend_buffer_t ggml_backend_metal_buffer_from_ptr(void * data, size_t size, size_t max_size),
"obsoleted by the new device interface - https://github.com/ggml-org/llama.cpp/pull/9713");
GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data); GGML_BACKEND_API void ggml_backend_metal_set_abort_callback(ggml_backend_t backend, ggml_abort_callback abort_callback, void * user_data);
GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_metal_buffer_type(void);
// helper to check if the device supports a specific family // helper to check if the device supports a specific family
// ideally, the user code should be doing these checks // ideally, the user code should be doing these checks
// ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf // ref: https://developer.apple.com/metal/Metal-Feature-Set-Tables.pdf

View File

@ -1404,6 +1404,7 @@ extern "C" {
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b); struct ggml_tensor * b);
// note: casting from f32 to i32 will discard the fractional part
GGML_API struct ggml_tensor * ggml_cast( GGML_API struct ggml_tensor * ggml_cast(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
@ -1528,7 +1529,11 @@ extern "C" {
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a); struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1] // supports 4D a:
// a [n_embd, ne1, ne2, ne3]
// b I32 [n_rows, ne2, ne3, 1]
//
// return [n_embd, n_rows, ne2, ne3]
GGML_API struct ggml_tensor * ggml_get_rows( GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, // data struct ggml_tensor * a, // data

View File

@ -8,7 +8,7 @@
extern "C" { extern "C" {
#endif #endif
#define GGML_BACKEND_API_VERSION 1 #define GGML_BACKEND_API_VERSION 2
// //
// Backend buffer type // Backend buffer type
@ -114,6 +114,9 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event); void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on on a different stream // wait for an event on on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event); void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
// (optional) sort/optimize the nodes in the graph
void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
}; };
struct ggml_backend { struct ggml_backend {

View File

@ -400,9 +400,8 @@ ggml_backend_t ggml_backend_init_by_type(enum ggml_backend_dev_type type, const
ggml_backend_t ggml_backend_init_best(void) { ggml_backend_t ggml_backend_init_best(void) {
ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU); ggml_backend_dev_t dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_GPU);
if (!dev) { dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_IGPU);
dev = ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU); dev = dev ? dev : ggml_backend_dev_by_type(GGML_BACKEND_DEVICE_TYPE_CPU);
}
if (!dev) { if (!dev) {
return nullptr; return nullptr;
} }

View File

@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
backend->iface.event_wait(backend, event); backend->iface.event_wait(backend, event);
} }
static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
if (backend->iface.optimize_graph != NULL) {
backend->iface.optimize_graph(backend, cgraph);
}
}
// Backend device // Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) { const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
struct ggml_backend_sched_split * split = &sched->splits[i]; struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(graph, split->i_start, split->i_end); split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
// Optimize this split of the graph. This needs to happen before we make graph_copy,
// so they are in sync.
ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split // add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) { for (int j = 0; j < split->n_inputs; j++) {
assert(graph_copy->size > (graph_copy->n_nodes + 1)); assert(graph_copy->size > (graph_copy->n_nodes + 1));

View File

@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
/* .graph_compute = */ ggml_backend_blas_graph_compute, /* .graph_compute = */ ggml_backend_blas_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
static ggml_guid_t ggml_backend_blas_guid(void) { static ggml_guid_t ggml_backend_blas_guid(void) {

View File

@ -2268,8 +2268,6 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
* stream, and persistent buffers for rope init/cache. * stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation * @param dst The destination ggml_tensor whose computation
* depends on the RoPE values (usually Qcur/Kcur). * depends on the RoPE values (usually Qcur/Kcur).
* @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
* @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
* @param theta_scale Scalar exponent base for computing theta scale values. * @param theta_scale Scalar exponent base for computing theta scale values.
* @param freq_scale Frequency scaling factor, applied to theta scale. * @param freq_scale Frequency scaling factor, applied to theta scale.
* @param attn_factor Attention scaling factor, applied to sin/cos. * @param attn_factor Attention scaling factor, applied to sin/cos.
@ -2277,17 +2275,23 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
* (dim expansion vs repeat_interleave). * (dim expansion vs repeat_interleave).
*/ */
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float* corr_dims, float ext_factor, float* corr_dims, float ext_factor,
float theta_scale, float freq_scale, float theta_scale, float freq_scale,
float attn_factor, bool is_neox) { float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on
// @param.is_neox
ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors ggml_tensor* src2 = dst->src[2]; // freq_factors
if(src2 == nullptr && ctx.rope_cache.cached
&& ctx.rope_cache.ext_factor == ext_factor
&& ctx.rope_cache.theta_scale == theta_scale
&& ctx.rope_cache.freq_scale == freq_scale
&& ctx.rope_cache.attn_factor == attn_factor
&& ctx.rope_cache.is_neox == is_neox) {
// use cache.
return;
}
int64_t theta_scale_length = src0->ne[0] / 2; int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float), size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
@ -2316,8 +2320,6 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ctx.rope_cache.freq_scale != freq_scale) { ctx.rope_cache.freq_scale != freq_scale) {
ctx.rope_cache.theta_scale_length = theta_scale_length; ctx.rope_cache.theta_scale_length = theta_scale_length;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
if (ctx.rope_cache.theta_scale_cache != nullptr) { if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache)); ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
@ -2342,7 +2344,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
// return MIN(1, MAX(0, y)) - 1; // return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float)); yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void* yarn_ramp_buffer = yarn_ramp_allocator.get(); void* yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t), acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float zero_value = 0, one_value = 1; float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]); float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
@ -2411,6 +2413,20 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor); ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
} }
// init sin_repeat && cos_repeat, only to accelerate first layer on each device
if (position_length > ctx.rope_cache.position_length) {
ctx.rope_cache.position_length = position_length;
if (ctx.rope_cache.sin_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.sin_cache));
}
if (ctx.rope_cache.cos_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.cos_cache));
}
int64_t repeat_theta_length = theta_scale_length * position_length * 2;
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.sin_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.cos_cache, repeat_theta_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
}
// position // position
aclTensor* acl_position_tensor = ggml_cann_create_tensor( aclTensor* acl_position_tensor = ggml_cann_create_tensor(
src1->data, ggml_cann_type_mapping(src1->type), src1->data, ggml_cann_type_mapping(src1->type),
@ -2462,10 +2478,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
} }
aclTensor* acl_sin_repeat_tensor = aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor = aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat // repeat
@ -2483,6 +2499,14 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
num_repeats, output_size); num_repeats, output_size);
} }
// Other layers use cache except first layer.
ctx.rope_cache.cached = true;
ctx.rope_cache.ext_factor = ext_factor;
ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
ctx.rope_cache.attn_factor = attn_factor;
ctx.rope_cache.is_neox = is_neox;
ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor, ggml_cann_release_resources(ctx, acl_theta_scale_tensor, acl_position_tensor,
acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor, acl_theta_tensor, acl_sin_tensor, acl_sin_repeat_tensor, acl_cos_tensor,
acl_cos_repeat_tensor); acl_cos_repeat_tensor);
@ -2504,10 +2528,7 @@ aclnnStatus aclnnRotaryPositionEmbedding(void* workspace,
#endif #endif
void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) { void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: use ascendc
// Only test with LLAMA model.
ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1];
// param // param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@ -2538,15 +2559,8 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
// sin/cos tensor length.
int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
void *sin_tensor_buffer = sin_tensor_allocator.get();
void *cos_tensor_buffer = cos_tensor_allocator.get();
// init ctx.rope_cos/rope_sin cache // init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor, aclnn_cache_init(ctx, dst, corr_dims, ext_factor,
theta_scale, freq_scale, attn_factor, is_neox); theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
@ -2556,10 +2570,10 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
} }
aclTensor* acl_sin_reshape_tensor = aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float), ggml_cann_create_tensor(ctx.rope_cache.sin_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor = aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float), ggml_cann_create_tensor(ctx.rope_cache.cos_cache, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_src = ggml_cann_create_tensor(src0);

View File

@ -38,6 +38,7 @@
#include <unistd.h> #include <unistd.h>
#include <functional> #include <functional>
#include <optional> #include <optional>
#include <list>
#include "../include/ggml-cann.h" #include "../include/ggml-cann.h"
#include "../include/ggml.h" #include "../include/ggml.h"
@ -106,6 +107,7 @@ int32_t ggml_cann_get_device();
std::optional<std::string> get_env(const std::string& name); std::optional<std::string> get_env(const std::string& name);
bool parse_bool(const std::string& value); bool parse_bool(const std::string& value);
int parse_integer(const std::string& value);
/** /**
* @brief Abstract base class for memory pools used by CANN. * @brief Abstract base class for memory pools used by CANN.
@ -350,7 +352,7 @@ struct ggml_graph_node_properties {
struct ggml_cann_graph { struct ggml_cann_graph {
~ggml_cann_graph() { ~ggml_cann_graph() {
if (graph != nullptr) { if (graph != nullptr) {
aclmdlRIDestroy(graph); ACL_CHECK(aclmdlRIDestroy(graph));
} }
} }
@ -358,6 +360,64 @@ struct ggml_cann_graph {
std::vector<ggml_graph_node_properties> ggml_graph_properties; std::vector<ggml_graph_node_properties> ggml_graph_properties;
}; };
/**
* @brief LRU cache for managing ggml_cann_graph objects.
*
* This class maintains a list of shared_ptr to ggml_cann_graph objects
* and enforces a maximum capacity. It provides methods to push new graphs,
* move existing graphs to the front (most recently used), and clear the cache.
*/
struct ggml_cann_graph_lru_cache {
size_t capacity; /**< Maximum number of graphs in the cache. */
std::list<ggml_cann_graph*> cache_list; /**< List storing cached graphs as raw pointers. */
ggml_cann_graph_lru_cache() {
capacity = parse_integer(get_env("GGML_CANN_GRAPH_CACHE_CAPACITY").value_or("12"));
}
/**
* @brief Push a new graph to the front of the cache.
* If the cache exceeds capacity, the least recently used graph is deleted.
* @param new_node Pointer to the new ggml_cann_graph to cache.
* Ownership is transferred to the cache (cache will delete it).
*/
void push(ggml_cann_graph* new_node) {
if (cache_list.size() >= capacity) {
ggml_cann_graph* old = cache_list.back();
cache_list.pop_back();
delete old; // free the old graph
}
cache_list.push_front(new_node);
}
/**
* @brief Move an existing graph to the front of the cache.
* @param node Pointer to the ggml_cann_graph to move.
*/
void move_to_front(ggml_cann_graph* node) {
cache_list.remove(node);
cache_list.push_front(node);
}
/**
* @brief Clear all graphs from the cache (also frees memory).
*/
void clear() {
for (auto ptr : cache_list) {
delete ptr;
}
cache_list.clear();
}
/**
* @brief Destructor that clears the cache and frees all cached graphs.
*/
~ggml_cann_graph_lru_cache() {
clear();
}
};
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
struct ggml_cann_rope_cache { struct ggml_cann_rope_cache {
@ -365,12 +425,27 @@ struct ggml_cann_rope_cache {
if(theta_scale_cache != nullptr) { if(theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(theta_scale_cache)); ACL_CHECK(aclrtFree(theta_scale_cache));
} }
if(sin_cache != nullptr) {
ACL_CHECK(aclrtFree(sin_cache));
}
if(cos_cache != nullptr) {
ACL_CHECK(aclrtFree(cos_cache));
}
} }
void* theta_scale_cache = nullptr; void* theta_scale_cache = nullptr;
int64_t theta_scale_length = 0; int64_t theta_scale_length = 0;
// sin/cos cache, used only to accelerate first layer on each device
void* sin_cache = nullptr;
void* cos_cache = nullptr;
int64_t position_length = 0;
// Properties to check before reusing the sincos cache
bool cached = false;
float ext_factor = 0.0f;
float theta_scale = 0.0f; float theta_scale = 0.0f;
float freq_scale = 0.0f; float freq_scale = 0.0f;
float attn_factor = 0.0f;
bool is_neox = false;
}; };
struct ggml_cann_tensor_cache { struct ggml_cann_tensor_cache {
@ -394,7 +469,7 @@ struct ggml_backend_cann_context {
aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */ aclrtEvent copy_event = nullptr; /**< Event for managing copy operations. */
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph. /// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph; ggml_cann_graph_lru_cache graph_lru_cache;
bool acl_graph_mode = true; bool acl_graph_mode = true;
#endif #endif
cann_task_queue task_queue; cann_task_queue task_queue;

View File

@ -116,6 +116,24 @@ bool parse_bool(const std::string& value) {
return valid_values.find(value) != valid_values.end(); return valid_values.find(value) != valid_values.end();
} }
/**
* @brief Parse a string as an integer, returning 0 if invalid.
*
* This function attempts to convert the input string `value` to an `int`.
* If the string is not a valid integer or is out of the `int` range,
* it returns 0.
*
* @param value The string to parse.
* @return The parsed integer, or 0 if conversion fails.
*/
int parse_integer(const std::string& value) {
try {
return std::stoi(value);
} catch (...) {
return 0;
}
}
/** /**
* @brief Initialize the CANN device information. * @brief Initialize the CANN device information.
* *
@ -2092,16 +2110,17 @@ static bool ggml_backend_cann_cpy_tensor_async(
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE, ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_src->stream())); cann_ctx_src->stream()));
// record event on src stream after the copy // record event on src stream after the copy
if (!cann_ctx_src->copy_event) { // TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC)); // if (!cann_ctx_src->copy_event) {
} // ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream())); // }
// ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
// wait on dst stream for the copy to complete // // wait on dst stream for the copy to complete
ggml_cann_set_device(cann_ctx_dst->device); // ggml_cann_set_device(cann_ctx_dst->device);
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event)); // ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
} else { } else {
// src and dst are on the same backend // src and dst are on the same backend
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
@ -2130,30 +2149,52 @@ static void ggml_backend_cann_synchronize(ggml_backend_t backend) {
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
/** /**
* @brief Populate the internal CANN graph node properties from the ggml computation graph. * @brief Add a new CANN graph to the LRU cache by populating node properties from the ggml graph.
* *
* This function copies all node attributes (operation type, dimensions, strides, input sources, * This function creates a new ggml_cann_graph object and fills its node properties
* and operation parameters) into the cached CANN graph structure for later reuse or comparison. * (operation type, dimensions, strides, input sources, and operation parameters)
* based on the current ggml computation graph.
* *
* @param cann_ctx The CANN backend context. * Each node in the ggml graph is mapped to a property entry in the new CANN graph:
* @param cgraph The ggml computational graph. * - node address
* - operation type
* - shape (ne) and strides (nb)
* - source tensor addresses
* - operation parameters
*
* After initialization, the new graph is pushed into the LRU cache owned by the
* CANN backend context. The cache takes ownership of the graph and manages its
* lifetime (including deletion upon eviction).
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph.
*/ */
static void set_ggml_graph_node_properties(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { static void add_lru_matched_graph_node_properties(
for (int node_idx = 0; node_idx < cgraph->n_nodes; node_idx++) { ggml_backend_cann_context * cann_ctx,
ggml_tensor * node = cgraph->nodes[node_idx]; ggml_cgraph * cgraph) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_address = node->data; // Create a new ggml_cann_graph object on the heap (its lifetime is managed by the cache).
cann_ctx->cann_graph->ggml_graph_properties[node_idx].node_op = node->op; ggml_cann_graph * new_graph = new ggml_cann_graph();
new_graph->ggml_graph_properties.resize(cgraph->n_nodes);
for (int dim = 0; dim < GGML_MAX_DIMS; dim++) { for (int node_idx = 0; node_idx < cgraph->n_nodes; ++node_idx) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].ne[dim] = node->ne[dim]; ggml_tensor * node = cgraph->nodes[node_idx];
cann_ctx->cann_graph->ggml_graph_properties[node_idx].nb[dim] = node->nb[dim]; auto & prop = new_graph->ggml_graph_properties[node_idx];
prop.node_address = node->data;
prop.node_op = node->op;
std::copy_n(node->ne, GGML_MAX_DIMS, prop.ne);
std::copy_n(node->nb, GGML_MAX_DIMS, prop.nb);
for (int src = 0; src < GGML_MAX_SRC; ++src) {
prop.src_address[src] = node->src[src] ? node->src[src]->data : nullptr;
} }
for (int src = 0; src < GGML_MAX_SRC; src++) {
cann_ctx->cann_graph->ggml_graph_properties[node_idx].src_address[src] = memcpy(prop.op_params, node->op_params, GGML_MAX_OP_PARAMS);
node->src[src] ? node->src[src]->data : nullptr;
}
memcpy(cann_ctx->cann_graph->ggml_graph_properties[node_idx].op_params, node->op_params, GGML_MAX_OP_PARAMS);
} }
// Insert into the LRU cache (cache takes ownership and will delete it when evicted).
cann_ctx->graph_lru_cache.push(new_graph);
} }
/** /**
@ -2198,30 +2239,45 @@ static bool ggml_graph_node_has_matching_properties(ggml_tensor * node, ggml_gra
} }
/** /**
* @brief Determine if the CANN graph needs to be rebuilt due to graph changes. * @brief Check whether there is a cached CANN graph that matches the current ggml graph.
* *
* This checks whether the number or properties of ggml graph nodes have changed * This function iterates through the cached CANN graphs stored in the LRU cache and
* compared to the last captured CANN graph. If so, the CANN graph must be re-captured. * compares them against the given ggml computation graph. A match requires that the
* number of nodes is the same and that each nodes properties (operation type,
* dimensions, strides, inputs, and operation parameters) are identical.
* *
* @param cann_ctx The CANN backend context. * If a matching graph is found, it is promoted to the front of the LRU cache and the
* function returns true. Otherwise, the function returns false, indicating that a new
* CANN graph needs to be captured.
*
* @param cann_ctx The CANN backend context containing the graph cache.
* @param cgraph The current ggml computation graph. * @param cgraph The current ggml computation graph.
* @return true if an update is required; false otherwise. * @return true if a matching cached graph exists; false otherwise.
*/ */
static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) { static bool is_matched_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph) {
// The number of nodes is different, so the graph needs to be reconstructed. ggml_cann_graph_lru_cache &lru_cache = cann_ctx->graph_lru_cache;
if (cann_ctx->cann_graph->ggml_graph_properties.size() != (size_t)cgraph->n_nodes) { for (auto &graph_ptr : lru_cache.cache_list) {
cann_ctx->cann_graph->ggml_graph_properties.resize(cgraph->n_nodes); // Skip graphs with a different number of nodes.
return true; if (graph_ptr->ggml_graph_properties.size() != static_cast<size_t>(cgraph->n_nodes)) {
} continue;
}
// The number of nodes is the same; iterate over each node to check whether they match. // Check if all nodes match.
for (int i = 0; i < cgraph->n_nodes; i++) { bool all_match = true;
bool has_matching_properties = ggml_graph_node_has_matching_properties( for (int i = 0; i < cgraph->n_nodes; ++i) {
cgraph->nodes[i], &cann_ctx->cann_graph->ggml_graph_properties[i]); if (!ggml_graph_node_has_matching_properties(cgraph->nodes[i], &graph_ptr->ggml_graph_properties[i])) {
if(!has_matching_properties) { all_match = false;
break;
}
}
if (all_match) {
// update cache_list && renturn graph_ptr
lru_cache.move_to_front(graph_ptr);
return true; return true;
} }
} }
return false; return false;
} }
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
@ -2240,17 +2296,13 @@ static bool is_cann_graph_update_required(ggml_backend_cann_context * cann_ctx,
* @param cann_graph_update_required Whether graph capture is needed due to graph changes. * @param cann_graph_update_required Whether graph capture is needed due to graph changes.
*/ */
static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph, static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx, ggml_cgraph * cgraph,
bool & use_cann_graph, bool & cann_graph_update_required) { bool & use_cann_graph, bool & cann_graph_update_required) {
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
ggml_cann_graph* matched_graph = cann_ctx->graph_lru_cache.cache_list.front();
if (use_cann_graph && cann_graph_update_required) { if (use_cann_graph && cann_graph_update_required) {
if (cann_ctx->cann_graph->graph != nullptr) {
ACL_CHECK(aclmdlRIDestroy(cann_ctx->cann_graph->graph));
cann_ctx->cann_graph->graph = nullptr;
}
ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL)); ACL_CHECK(aclmdlRICaptureBegin(cann_ctx->stream(), ACL_MODEL_RI_CAPTURE_MODE_GLOBAL));
} }
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
// Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph. // Only perform the graph execution if CANN graphs are not enabled, or we are capturing the graph.
// With the use of CANN graphs, the execution will be performed by the graph launch. // With the use of CANN graphs, the execution will be performed by the graph launch.
if (!use_cann_graph || cann_graph_update_required) { if (!use_cann_graph || cann_graph_update_required) {
@ -2271,12 +2323,12 @@ static void evaluate_and_capture_cann_graph(ggml_backend_cann_context * cann_ctx
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture if (use_cann_graph && cann_graph_update_required) { // End CANN graph capture
ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &cann_ctx->cann_graph->graph)); ACL_CHECK(aclmdlRICaptureEnd(cann_ctx->stream(), &matched_graph->graph));
} }
if (use_cann_graph) { if (use_cann_graph) {
// Execute graph // Execute graph
ACL_CHECK(aclmdlRIExecuteAsync(cann_ctx->cann_graph->graph, cann_ctx->stream())); ACL_CHECK(aclmdlRIExecuteAsync(matched_graph->graph, cann_ctx->stream()));
} }
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
} }
@ -2301,28 +2353,44 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_cann_set_device(cann_ctx->device); ggml_cann_set_device(cann_ctx->device);
g_nz_workspaces[cann_ctx->device].clear(); g_nz_workspaces[cann_ctx->device].clear();
// calculate rope cache for fist layer in current device.
cann_ctx->rope_cache.cached = false;
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
bool use_cann_graph = true; bool use_cann_graph = true;
bool cann_graph_update_required = false; bool cann_graph_update_required = false;
static bool prefill_use_graph = parse_bool(get_env("GGML_CANN_PREFILL_USE_GRAPH").value_or(""));
if (!prefill_use_graph) {
// Do not use acl_graph for prefill.
for (int i = 0; i < cgraph->n_nodes; i++) {
ggml_tensor * node = cgraph->nodes[i];
// TODO: Optimize here. Currently, we can only
// get seq_len by FA's input.
if (node->op == GGML_OP_FLASH_ATTN_EXT) {
// Q -> src[0], shape: [B, S, N, D]
use_cann_graph = (node->src[0]->ne[1] == 1);
break;
}
}
}
if (!cann_ctx->acl_graph_mode) { if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false; use_cann_graph = false;
} }
if (use_cann_graph) { if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) { // If no matching graph is found, the graph needs to be recaptured.
cann_ctx->cann_graph.reset(new ggml_cann_graph()); cann_graph_update_required = !is_matched_graph(cann_ctx, cgraph);
cann_graph_update_required = true; if (cann_graph_update_required) {
// If no matching graph is found, add a new ACL graph.
add_lru_matched_graph_node_properties(cann_ctx, cgraph);
} }
cann_graph_update_required = is_cann_graph_update_required(cann_ctx, cgraph);
set_ggml_graph_node_properties(cann_ctx, cgraph);
} }
#else #else
bool use_cann_graph = false; bool use_cann_graph = false;
bool cann_graph_update_required = false; bool cann_graph_update_required = false;
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
evaluate_and_capture_cann_graph( evaluate_and_capture_cann_graph(
cann_ctx, cann_ctx,
cgraph, cgraph,
@ -2689,6 +2757,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
/* .graph_compute = */ ggml_backend_cann_graph_compute, /* .graph_compute = */ ggml_backend_cann_graph_compute,
/* .event_record = */ ggml_backend_cann_event_record, /* .event_record = */ ggml_backend_cann_event_record,
/* .event_wait = */ ggml_backend_cann_event_wait, /* .event_wait = */ ggml_backend_cann_event_wait,
/* .optimize_graph = */ NULL,
}; };
/** /**

View File

@ -224,7 +224,13 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME) foreach(feature DOTPROD SVE MATMUL_INT8 FMA FP16_VECTOR_ARITHMETIC SME)
string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos) string(FIND "${ARM_FEATURE}" "__ARM_FEATURE_${feature} 1" feature_pos)
if (NOT ${feature_pos} EQUAL -1) if (NOT ${feature_pos} EQUAL -1)
message(STATUS "ARM feature ${feature} enabled") # Special handling for MATMUL_INT8 when machine doesn't support i8mm
if ("${feature}" STREQUAL "MATMUL_INT8" AND GGML_MACHINE_SUPPORTS_noi8mm)
message(STATUS "ARM feature ${feature} detected but unsetting due to machine not supporting i8mm")
list(APPEND ARCH_FLAGS -U__ARM_FEATURE_MATMUL_INT8)
else()
message(STATUS "ARM feature ${feature} enabled")
endif()
endif() endif()
endforeach() endforeach()
endif() endif()
@ -457,7 +463,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# TODO: Separation to determine activation of VX/VXE/VXE2 # TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562") if (${S390X_M} MATCHES "8561|8562")
set(GGML_NNPA OFF)
message(STATUS "z15 target") message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15) list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931") elseif (${S390X_M} MATCHES "3931")
@ -479,11 +484,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE) list(APPEND ARCH_DEFINITIONS GGML_VXE)
endif() endif()
if (GGML_NNPA)
message(STATUS "NNPA enabled")
list(APPEND ARCH_DEFINITIONS GGML_NNPA)
endif()
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected") message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)

View File

@ -53,9 +53,9 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
#if defined(__VXE__) || defined(__VXE2__) #if defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
__vector float srcv [8]; float32x4_t srcv [8];
__vector float asrcv[8]; float32x4_t asrcv[8];
__vector float amaxv[8]; float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
@ -74,8 +74,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
y[i].d = GGML_CPU_FP32_TO_FP16(d); y[i].d = GGML_CPU_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) { for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id)); const float32x4_t v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v); const int32x4_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1); y[i].qs[4*j + 1] = vec_extract(vi, 1);
@ -98,9 +98,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
#if defined(__VXE__) || defined(__VXE2__) #if defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) { for (int i = 0; i < nb; i++) {
__vector float srcv [8]; float32x4_t srcv [8];
__vector float asrcv[8]; float32x4_t asrcv[8];
__vector float amaxv[8]; float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j); for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]); for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
@ -118,11 +118,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
y[i].d = GGML_CPU_FP32_TO_FP16(d); y[i].d = GGML_CPU_FP32_TO_FP16(d);
__vector int32_t acc = vec_splats(0); int32x4_t acc = vec_splats(0);
for (int j = 0; j < 8; j++) { for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id)); const float32x4_t v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v); const int32x4_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0); y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1); y[i].qs[4*j + 1] = vec_extract(vi, 1);
@ -162,37 +162,36 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
float sumf = 0; float sumf = 0;
#if defined(__VXE__) || defined(__VXE2__) #if defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f); float32x4_t acc = vec_splats(0.0f);
const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F); const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);
const __vector int8_t v_s = vec_splats( (const int8_t)0x08); const int8x16_t v_s = vec_splats( (const int8_t)0x08);
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
const __vector uint8_t v_x = vec_xl(0, x[ib].qs); const uint8x16_t v_x = vec_xl(0, x[ib].qs);
const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m); const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4); const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
const __vector int8_t v_xls = vec_sub(v_xl, v_s); const int8x16_t v_xls = vec_sub(v_xl, v_s);
const __vector int8_t v_xhs = vec_sub(v_xh, v_s); const int8x16_t v_xhs = vec_sub(v_xh, v_s);
const __vector int8_t v_yl = vec_xl(0 , y[ib].qs); const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs); const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl); const int16x8_t v_xylso = vec_mulo(v_xls, v_yl);
const __vector int16_t v_xylse = vec_mule(v_xls, v_yl); const int16x8_t v_xylse = vec_mule(v_xls, v_yl);
const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh); const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh);
const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh); const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh);
__vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_); int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
const __vector float v_xy = vec_float(vec_unpackh(v_xy_)); const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_));
const __vector float v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d)); const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
acc = vec_madd(v_xy, v_d, acc); acc = vec_madd(v_xy, v_d, acc);
} }
sumf = acc[0] + acc[1] + acc[2] + acc[3]; sumf = vec_hsum_f32x4(acc);
*s = sumf; *s = sumf;
#else #else
UNUSED(nb); UNUSED(nb);
@ -249,8 +248,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
acc = vec_madd(v_xy, v_d, acc); acc = vec_madd(v_xy, v_d, acc);
} }
sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs; sumf = vec_hsum_f32x4(acc) + summs;
*s = sumf; *s = sumf;
#else #else
UNUSED(nb); UNUSED(nb);
@ -351,7 +349,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
} }
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1); sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1);
#pragma GCC unroll 4 #pragma GCC unroll 4
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
@ -390,7 +388,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f)); const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
sumf += vec_hsum(v_acc); sumf += vec_hsum_f32x4(v_acc);
} }
*s = sumf; *s = sumf;
@ -502,7 +500,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1); v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
} }
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1; sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1;
#pragma GCC unroll 4 #pragma GCC unroll 4
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
@ -543,7 +541,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d)); const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc); const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);
sumf += vec_hsum(v_acc) + summs; sumf += vec_hsum_f32x4(v_acc) + summs;
} }
*s = sumf; *s = sumf;
@ -575,7 +573,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
float sumf = 0; float sumf = 0;
#if defined(__VXE__) || defined(__VXE2__) #if defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f); float32x4_t acc = vec_splats(0.0f);
#pragma GCC unroll 8 #pragma GCC unroll 8
for (; ib < nb; ++ib) { for (; ib < nb; ++ib) {
@ -594,7 +592,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
acc = vec_madd(v_xy, v_d, acc); acc = vec_madd(v_xy, v_d, acc);
} }
sumf = acc[0] + acc[1] + acc[2] + acc[3]; sumf = vec_hsum_f32x4(acc);
*s = sumf; *s = sumf;
#else #else
@ -718,10 +716,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]); isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]); isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0]; isum += vec_hsum_i32x4(isum0) * scale[0];
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1]; isum += vec_hsum_i32x4(isum1) * scale[1];
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2]; isum += vec_hsum_i32x4(isum2) * scale[2];
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3]; isum += vec_hsum_i32x4(isum3) * scale[3];
scale += 4; scale += 4;
@ -819,7 +817,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm); v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0]; sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0];
v_y[0] = vec_xl(0 , y0); v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0); v_y[1] = vec_xl(16, y0);
@ -829,7 +827,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4); v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]); const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1]; sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1];
} }
sumf += d * (sumi1 + sumi2); sumf += d * (sumi1 + sumi2);
@ -911,7 +909,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh); const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh); const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
const int32x4_t v_mins = vec_add(v_minsho, v_minshe); const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; const int32_t mins = vec_hsum_i32x4(v_mins);
const uint8_t * scales = (const uint8_t *)utmp; const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * GGML_RESTRICT x0l = x[i].qs; const uint8_t * GGML_RESTRICT x0l = x[i].qs;
@ -948,8 +946,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]); int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]); int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++; sumi += vec_hsum_i32x4(sumi0) * *scales++;
sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++; sumi += vec_hsum_i32x4(sumi1) * *scales++;
} }
sumf += d * sumi - dmin * mins; sumf += d * sumi - dmin * mins;
@ -1020,7 +1018,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh); const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe; const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3]; const int32_t mins = vec_hsum_i32x4(v_mins);
int32_t isum = 0; int32_t isum = 0;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
@ -1060,10 +1058,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + isum += vec_hsum_i32x4(summs0) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + vec_hsum_i32x4(summs1) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + vec_hsum_i32x4(summs2) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; vec_hsum_i32x4(summs3) * scale[3];
scale += 4; scale += 4;
@ -1094,10 +1092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]); summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]); summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] + isum += vec_hsum_i32x4(summs0) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] + vec_hsum_i32x4(summs1) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] + vec_hsum_i32x4(summs2) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3]; vec_hsum_i32x4(summs3) * scale[3];
scale += 4; scale += 4;
} }
@ -1285,7 +1283,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs); const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh); const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]); sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy);
} }
*s = sumf; *s = sumf;
@ -1354,8 +1352,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
h >>= 4; h >>= 4;
sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1; sumi1 += vec_hsum_i32x4(vsumi0) * ls1;
sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2; sumi2 += vec_hsum_i32x4(vsumi1) * ls2;
} }
sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2); sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);

View File

@ -68,12 +68,6 @@ struct ggml_compute_params {
#endif // __VXE2__ #endif // __VXE2__
#endif // __s390x__ && __VEC__ #endif // __s390x__ && __VEC__
#if defined(__s390x__) && defined(GGML_NNPA)
#ifndef __NNPA__
#define __NNPA__
#endif // __NNPA__
#endif // __s390x__ && GGML_NNPA
#if defined(__ARM_FEATURE_SVE) #if defined(__ARM_FEATURE_SVE)
#include <sys/prctl.h> #include <sys/prctl.h>
#endif #endif
@ -489,11 +483,16 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
/** /**
* @see https://github.com/ggml-org/llama.cpp/pull/14037 * @see https://github.com/ggml-org/llama.cpp/pull/14037
*/ */
inline static float vec_hsum(float32x4_t v) { inline static float vec_hsum_f32x4(float32x4_t v) {
float32x4_t v_temp = v + vec_reve(v); float32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1]; return v_temp[0] + v_temp[1];
} }
inline static int32_t vec_hsum_i32x4(int32x4_t v) {
int32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) { inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b); const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
return acc + (vec_unpackh(p) + vec_unpackl(p)); return acc + (vec_unpackh(p) + vec_unpackl(p));

View File

@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K, .vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1, .nrows = 1,
}, },
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
}; };
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) { const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@ -2696,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan(
if (ggml_is_quantized(node->type) || if (ggml_is_quantized(node->type) ||
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32 // F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) || (node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) { (node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
// conversion between F32 and I32
(node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
(node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks; cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
} }
} break; } break;
@ -3211,21 +3217,6 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT); __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec); _mm_storel_epi64((__m128i *)(y + i), y_vec);
} }
#elif defined(__NNPA__)
for (; i + 7 < n; i += 8) {
float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0));
float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4));
uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
}
for (; i + 3 < n; i += 4) {
float32x4_t v_x = vec_xl(0, (const float *)(x + i));
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
}
#elif defined(__riscv_zvfh) #elif defined(__riscv_zvfh)
for (int vl; i < n; i += vl) { for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i); vl = __riscv_vsetvl_e32m2(n - i);
@ -3259,21 +3250,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec); __m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec); _mm_storeu_ps(y + i, y_vec);
} }
#elif defined(__NNPA__)
for (; i + 7 < n; i += 8) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i + 0));
vec_xst(v_yl, 0, (float *)(y + i + 4));
}
for (; i + 3 < n; i += 4) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i));
}
#endif #endif
for (; i < n; ++i) { for (; i < n; ++i) {
@ -3288,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
} }
} }
void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
int64_t i = 0;
for (; i < n; ++i) {
y[i] = x[i];
}
}
void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) { void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
int64_t i = 0; int64_t i = 0;
#if defined(__AVX2__) #if defined(__AVX2__)
@ -3477,14 +3460,6 @@ int ggml_cpu_has_vxe(void) {
#endif #endif
} }
int ggml_cpu_has_nnpa(void) {
#if defined(GGML_NNPA)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_neon(void) { int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON) #if defined(__ARM_ARCH) && defined(__ARM_NEON)
return 1; return 1;

View File

@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute, /* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
static ggml_guid_t ggml_backend_cpu_guid(void) { static ggml_guid_t ggml_backend_cpu_guid(void) {
@ -348,8 +349,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t *
long pages = sysconf(_SC_PHYS_PAGES); long pages = sysconf(_SC_PHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE); long page_size = sysconf(_SC_PAGE_SIZE);
*total = pages * page_size; *total = pages * page_size;
// "free" system memory is ill-defined, for practical purposes assume that all of it is free:
*free = *total; *free = *total;
#endif #endif // _WIN32
GGML_UNUSED(dev); GGML_UNUSED(dev);
} }
@ -576,9 +579,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_vxe()) { if (ggml_cpu_has_vxe()) {
features.push_back({ "VXE", "1" }); features.push_back({ "VXE", "1" });
} }
if (ggml_cpu_has_nnpa()) {
features.push_back({ "NNPA", "1" });
}
if (ggml_cpu_has_wasm_simd()) { if (ggml_cpu_has_wasm_simd()) {
features.push_back({ "WASM_SIMD", "1" }); features.push_back({ "WASM_SIMD", "1" });
} }

View File

@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (dst->src[0]->type == GGML_TYPE_Q4_0) { if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst); return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) { } else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_kv_cache(params, dst); return compute_forward_fp16(params, dst);
} }
} else if (dst->op == GGML_OP_GET_ROWS) { } else if (dst->op == GGML_OP_GET_ROWS) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) { if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return false; return false;
} }
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) { bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT; static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
const ggml_tensor * src0 = dst->src[0]; const ggml_tensor * src0 = dst->src[0];
@ -515,9 +515,6 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
op->src[0]->buffer && op->src[0]->buffer &&
(ggml_n_dims(op->src[0]) == 2) && (ggml_n_dims(op->src[0]) == 2) &&
op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) { op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type() && ctx.kernels) {
if (op->op == GGML_OP_GET_ROWS && op->src[1]->ne[0] != 8) {
return false;
}
if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) { if (op->src[1]->buffer && !ggml_backend_buft_is_host(op->src[1]->buffer->buft)) {
return false; return false;
} }
@ -534,13 +531,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) { if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra; return (ggml::cpu::tensor_traits *) op->src[0]->extra;
} }
else if (ggml_kleidiai_select_kernels(ctx.features, op) && else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
op->src[0]->op == GGML_OP_VIEW && if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[0] != 2) ||
(op->src[1]->nb[0] != 4) ||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) { (op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
return nullptr; return nullptr;
} }

View File

@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
id += ne00 * (ne01 - ir1); id += ne00 * (ne01 - ir1);
} }
} }
} else if (dst->type == GGML_TYPE_I32) {
size_t id = 0;
int32_t * dst_ptr = (int32_t *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = *src0_ptr;
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else { } else {
GGML_ABORT("fatal error"); // TODO: implement GGML_ABORT("fatal error"); // TODO: implement
} }
@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
} }
} }
} }
} else if (dst->type == GGML_TYPE_I32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(int32_t *) dst_ptr = *(const float *) src0_ptr;
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
}
static void ggml_compute_forward_dup_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
// TODO: not optimal, but works
if (dst->type == GGML_TYPE_F32) {
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(float *) dst_ptr = *(const int32_t *) src0_ptr;
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else { } else {
GGML_ABORT("fatal error"); // TODO: implement GGML_ABORT("fatal error"); // TODO: implement
} }
@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
{ {
ggml_compute_forward_dup_f32(params, dst); ggml_compute_forward_dup_f32(params, dst);
} break; } break;
case GGML_TYPE_I32:
{
ggml_compute_forward_dup_i32(params, dst);
} break;
default: default:
{ {
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) { if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
@ -8438,6 +8598,7 @@ static void ggml_compute_forward_timestep_embedding_f32(
embed_data[j + half] = sinf(arg); embed_data[j + half] = sinf(arg);
} }
if (dim % 2 != 0 && ith == 0) { if (dim % 2 != 0 && ith == 0) {
embed_data[2 * half] = 0.f;
embed_data[dim] = 0.f; embed_data[dim] = 0.f;
} }
} }

View File

@ -114,26 +114,6 @@ extern "C" {
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x) #define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x) #define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x) #define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
#elif defined(__NNPA__)
#define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x)
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) {
uint16x8_t v_h = vec_splats(h);
uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0);
return vec_extend_to_fp32_hi(v_hd, 0)[0];
}
static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) {
float32x4_t v_f = vec_splats(f);
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0);
uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0);
return vec_extract(v_h, 0);
}
#endif #endif
// precomputed f32 table for f16 (256 KB) // precomputed f32 table for f16 (256 KB)
@ -1156,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_EPR GGML_F32_EPR #define GGML_F16_EPR GGML_F32_EPR
static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) { static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
#if defined(__NNPA__)
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x);
uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0);
return vec_extend_to_fp32_hi(v_xd, 0);
#else
float tmp[4]; float tmp[4];
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
@ -1170,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
// note: keep type-cast here to prevent compiler bugs // note: keep type-cast here to prevent compiler bugs
// see: https://github.com/ggml-org/llama.cpp/issues/12846 // see: https://github.com/ggml-org/llama.cpp/issues/12846
return vec_xl(0, (const float *)(tmp)); return vec_xl(0, (const float *)(tmp));
#endif
} }
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) { static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#if defined(__NNPA__)
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0);
uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0);
x[0] = vec_extract(v_x, 0);
x[1] = vec_extract(v_x, 1);
x[2] = vec_extract(v_x, 2);
x[3] = vec_extract(v_x, 3);
#else
float arr[4]; float arr[4];
// note: keep type-cast here to prevent compiler bugs // note: keep type-cast here to prevent compiler bugs
@ -1193,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
for (int i = 0; i < 4; i++) { for (int i = 0; i < 4; i++) {
x[i] = GGML_CPU_FP32_TO_FP16(arr[i]); x[i] = GGML_CPU_FP32_TO_FP16(arr[i]);
} }
#endif
} }
#define GGML_F16_VEC GGML_F32x4 #define GGML_F16_VEC GGML_F32x4

View File

@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS}) list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu") file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS}) list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS) if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu") file(GLOB SRCS "template-instances/fattn-vec*.cu")

View File

@ -23,28 +23,44 @@ static __device__ __forceinline__ float op_div(const float a, const float b) {
return a / b; return a / b;
} }
template <float (*bin_op)(const float, const float),
typename src0_t,
typename src1_t,
typename dst_t,
typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0,
const src1_t * src1,
dst_t * dst,
const int ne0,
const int ne1,
const int ne2,
const uint3 ne3,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const uint32_t i0s = blockDim.x * blockIdx.x + threadIdx.x;
const uint32_t i1 = (blockDim.y * blockIdx.y + threadIdx.y);
const uint32_t i2 = fastdiv((blockDim.z * blockIdx.z + threadIdx.z), ne3);
const uint32_t i3 = (blockDim.z * blockIdx.z + threadIdx.z) - (i2 * ne3.z);
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3.z) {
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs>
static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst_t * dst,
const int ne0, const int ne1, const int ne2, const int ne3,
const int ne10, const int ne11, const int ne12, const int ne13,
/*int s0, */ const int s1, const int s2, const int s3,
/*int s00,*/ const int s01, const int s02, const int s03,
/*int s10,*/ const int s11, const int s12, const int s13,
src1_ptrs... src1s) {
const int i0s = blockDim.x*blockIdx.x + threadIdx.x;
const int i1 = (blockDim.y*blockIdx.y + threadIdx.y);
const int i2 = (blockDim.z*blockIdx.z + threadIdx.z) / ne3;
const int i3 = (blockDim.z*blockIdx.z + threadIdx.z) % ne3;
if (i0s >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return; return;
} }
const int i11 = i1 % ne11; const uint32_t i11 = fastmodulo(i1, ne11);
const int i12 = i2 % ne12; const uint32_t i12 = fastmodulo(i2, ne12);
const int i13 = i3 % ne13; const uint32_t i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@ -53,8 +69,8 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst; dst_t * dst_row = dst + i_dst;
for (int i0 = i0s; i0 < ne0; i0 += blockDim.x*gridDim.x) { for (int i0 = i0s; i0 < ne0; i0 += blockDim.x * gridDim.x) {
const int i10 = i0 % ne10; const uint32_t i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) { if constexpr (sizeof...(src1_ptrs) > 0) {
@ -67,28 +83,48 @@ static __global__ void k_bin_bcast(const src0_t * src0, const src1_t * src1, dst
} }
} }
template <float (*bin_op)(const float, const float), typename src0_t, typename src1_t, typename dst_t, typename... src1_ptrs> template <float (*bin_op)(const float, const float),
static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t * src1, dst_t * dst, typename src0_t,
const int ne0, const int ne1, const int ne2,const int ne3, typename src1_t,
const int ne10, const int ne11, const int ne12, const int ne13, typename dst_t,
/*int s0, */ const int s1, const int s2, const int s3, typename... src1_ptrs>
/*int s00,*/ const int s01, const int s02, const int s03, static __global__ void k_bin_bcast_unravel(const src0_t * src0,
/*int s10,*/ const int s11, const int s12, const int s13, const src1_t * src1,
src1_ptrs ... src1s) { dst_t * dst,
const uint3 ne0,
const uint3 ne1,
const uint3 ne2,
const uint32_t ne3,
const uint3 prod_012,
const uint3 prod_01,
const uint3 ne10,
const uint3 ne11,
const uint3 ne12,
const uint3 ne13,
/*int s0, */ const int s1,
const int s2,
const int s3,
/*int s00,*/ const int s01,
const int s02,
const int s03,
/*int s10,*/ const int s11,
const int s12,
const int s13,
src1_ptrs... src1s) {
const int i = blockDim.x*blockIdx.x + threadIdx.x; const int i = blockDim.x*blockIdx.x + threadIdx.x;
const int i3 = i/(ne2*ne1*ne0); const uint32_t i3 = fastdiv(i, prod_012);
const int i2 = (i/(ne1*ne0)) % ne2; const uint32_t i2 = fastdiv(i - i3 * prod_012.z, prod_01);
const int i1 = (i/ne0) % ne1; const uint32_t i1 = fastdiv(i - i3 * prod_012.z - i2 * prod_01.z, ne0);
const int i0 = i % ne0; const uint32_t i0 = i - i3 * prod_012.z - i2 * prod_01.z - i1 * ne0.z;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) { if (i0 >= ne0.z || i1 >= ne1.z || i2 >= ne2.z || i3 >= ne3) {
return; return;
} }
const int i11 = i1 % ne11; const int i11 = fastmodulo(i1, ne11);
const int i12 = i2 % ne12; const int i12 = fastmodulo(i2, ne12);
const int i13 = i3 % ne13; const int i13 = fastmodulo(i3, ne13);
const size_t i_src0 = i3*s03 + i2*s02 + i1*s01; const size_t i_src0 = i3*s03 + i2*s02 + i1*s01;
const size_t i_src1 = i13*s13 + i12*s12 + i11*s11; const size_t i_src1 = i13*s13 + i12*s12 + i11*s11;
@ -97,7 +133,7 @@ static __global__ void k_bin_bcast_unravel(const src0_t * src0, const src1_t *
const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr; const src0_t * src0_row = src0 ? (src0 + i_src0) : nullptr;
dst_t * dst_row = dst + i_dst; dst_t * dst_row = dst + i_dst;
const int i10 = i0 % ne10; const int i10 = fastmodulo(i0, ne10);
float result = src0_row ? (float) src0_row[i0] : 0.0f; float result = src0_row ? (float) src0_row[i0] : 0.0f;
if constexpr (sizeof...(src1_ptrs) > 0) { if constexpr (sizeof...(src1_ptrs) > 0) {
@ -170,11 +206,6 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
//int64_t ne02 = cne0[2]; GGML_UNUSED(ne02); //int64_t ne02 = cne0[2]; GGML_UNUSED(ne02);
//int64_t ne03 = cne0[3]; GGML_UNUSED(ne03); //int64_t ne03 = cne0[3]; GGML_UNUSED(ne03);
int64_t ne10 = cne1[0];
int64_t ne11 = cne1[1];
int64_t ne12 = cne1[2];
int64_t ne13 = cne1[3];
size_t nb0 = cnb[0]; size_t nb0 = cnb[0];
size_t nb1 = cnb[1]; size_t nb1 = cnb[1];
size_t nb2 = cnb[2]; size_t nb2 = cnb[2];
@ -233,48 +264,51 @@ static void launch_bin_bcast_pack(const ggml_tensor * src0, const ggml_tensor *
block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x); block_dims.y = std::min<unsigned int>(ne1, block_size / block_dims.x);
block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U); block_dims.z = std::min(std::min<unsigned int>(ne2 * ne3, block_size / block_dims.x / block_dims.y), 64U);
dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, dim3 block_nums((hne0 + block_dims.x - 1) / block_dims.x, (ne1 + block_dims.y - 1) / block_dims.y,
(ne1 + block_dims.y - 1) / block_dims.y,
(ne2 * ne3 + block_dims.z - 1) / block_dims.z); (ne2 * ne3 + block_dims.z - 1) / block_dims.z);
const uint3 ne10 = init_fastdiv_values((uint32_t) cne1[0]);
const uint3 ne11 = init_fastdiv_values((uint32_t) cne1[1]);
const uint3 ne12 = init_fastdiv_values((uint32_t) cne1[2]);
const uint3 ne13 = init_fastdiv_values((uint32_t) cne1[3]);
if (block_nums.z > 65535) { if (block_nums.z > 65535) {
int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size; int block_num = (ne0 * ne1 * ne2 * ne3 + block_size - 1) / block_size;
const uint3 prod_012 = init_fastdiv_values((uint32_t) (ne0 * ne1 * ne2));
const uint3 prod_01 = init_fastdiv_values((uint32_t) (ne0 * ne1));
const uint3 ne0_fastdiv = init_fastdiv_values((uint32_t) ne0);
const uint3 ne1_fastdiv = init_fastdiv_values((uint32_t) ne1);
const uint3 ne2_fastdiv = init_fastdiv_values((uint32_t) ne2);
if constexpr (sizeof...(I) > 0) { if constexpr (sizeof...(I) > 0) {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t><<<block_num, block_size, 0, stream>>>(
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11,
ne0, ne1, ne2, ne3, ne12, ne13,
ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3,
/* s0, */ s1, s2, s3, /* s00,*/ s01, s02, s03,
/* s00,*/ s01, s02, s03, /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else { } else {
k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t> k_bin_bcast_unravel<bin_op, src0_t, src1_t, dst_t>
<<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, <<<block_num, block_size, 0, stream>>>(src0_dd, src1_dd, dst_dd, ne0_fastdiv, ne1_fastdiv,
ne0, ne1, ne2, ne3, ne2_fastdiv, ne3, prod_012, prod_01, ne10, ne11, ne12, ne13,
ne10, ne11, ne12, ne13, /* s0, */ s1, s2, s3,
/* s0, */ s1, s2, s3, /* s00,*/ s01, s02, s03,
/* s00,*/ s01, s02, s03, /* s10,*/ s11, s12, s13);
/* s10,*/ s11, s12,s13);
} }
} else { } else {
const uint3 ne3_fastdiv = init_fastdiv_values((uint32_t) ne3);
if constexpr (sizeof...(I) > 0) { if constexpr (sizeof...(I) > 0) {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t> k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3, /* s0, */ s1, s2, s3,
ne10, ne11, ne12, ne13, /* s00,*/ s01, s02, s03,
/* s0, */ s1, s2, s3, /* s10,*/ s11, s12, s13, (const src1_t *) dst->src[I + 1]->data...);
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13,
(const src1_t *) dst->src[I + 1]->data...);
} else { } else {
k_bin_bcast<bin_op, src0_t, src1_t, dst_t> k_bin_bcast<bin_op, src0_t, src1_t, dst_t><<<block_nums, block_dims, 0, stream>>>(
<<<block_nums, block_dims, 0, stream>>>(src0_dd, src1_dd, dst_dd, src0_dd, src1_dd, dst_dd, ne0, ne1, ne2, ne3_fastdiv, ne10, ne11, ne12, ne13,
ne0, ne1, ne2, ne3, /* s0, */ s1, s2, s3,
ne10, ne11, ne12, ne13, /* s00,*/ s01, s02, s03,
/* s0, */ s1, s2, s3, /* s10,*/ s11, s12, s13);
/* s00,*/ s01, s02, s03,
/* s10,*/ s11, s12,s13);
} }
} }
} }

View File

@ -545,6 +545,45 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP) #endif // defined(GGML_USE_HIP)
} }
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
acc += v*u;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
acc += v.x*u.x;
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(v*u);
acc += tmp.x + tmp.y;
#else
const float2 tmpv = __half22float2(v);
const float2 tmpu = __half22float2(u);
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
}
// Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD.
template <int nbytes>
static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) {
if constexpr (nbytes == 4) {
*(int *) dst = *(const int *) src;
} else if constexpr (nbytes == 8) {
*(int2 *) dst = *(const int2 *) src;
} else if constexpr (nbytes == 16) {
*(int4 *) dst = *(const int4 *) src;
} else {
static_assert(nbytes == 0 && nbytes == -1, "bad nbytes");
}
}
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) { static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080 #if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x); const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
@ -570,6 +609,8 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
// //
// n/d = (mulhi(n, mp) + n) >> L; // n/d = (mulhi(n, mp) + n) >> L;
static const uint3 init_fastdiv_values(uint32_t d) { static const uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
// compute L = ceil(log2(d)); // compute L = ceil(log2(d));
uint32_t L = 0; uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) { while (L < 32 && (uint32_t{ 1 } << L) < d) {

View File

@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x)); return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) { } else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x); return __bfloat162float(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else { } else {
return float(x); return float(x);
} }

View File

@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) { } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else { } else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__, GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type)); ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@ -1,371 +0,0 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
half2 * KQ2 = (half2 *) KQ;
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
half kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -HALF_MAX_HALF;
}
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ half2 Q_h2[ncols][D/2];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
}
}
__syncthreads();
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
half2 Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
half sum;
if (use_logit_softcap) {
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
sum = logit_softcap * tanhf(tmp.x + tmp.y);
} else {
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
const half2 val = h2exp(diff);
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
half2 V_k[(D/2)/WARP_SIZE][2];
half2 KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
}
}
}
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
kqmax[j0/nwarps] = kqmax_new_j;
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum((float)kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}

View File

@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1,379 +0,0 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f32.cuh"
#define FATTN_KQ_STRIDE_TILE_F32 32
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
float2 * KV_tmp2 = (float2 *) KV_tmp;
float kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ float Q_f[ncols][D];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
}
__syncthreads();
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
float Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
if (use_logit_softcap) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
}
}
__syncthreads();
#pragma unroll
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
float2 V_k[(D/2)/WARP_SIZE];
float KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
}
}
}
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}

View File

@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -0,0 +1,660 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
#define FATTN_TILE_NTHREADS 256
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
switch (D) {
case 64:
return 64;
case 128:
case 256:
if (GGML_CUDA_CC_IS_GCN(cc) || GGML_CUDA_CC_IS_CDNA(cc)) {
return ncols <= 16 ? 64 : 32;
} else {
return 64;
}
default:
GGML_ABORT("fatal error");
return -1;
}
}
if (fast_fp16_available(cc)) {
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
GGML_UNUSED(warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return 64;
case 128:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 32;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return 64;
case 128:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return 64;
#endif // defined(GCN) || defined(CDNA)
case 256:
#if defined(GCN) || defined(CDNA)
return ncols <= 16 ? 64 : 128;
#else
return ncols <= 16 ? 64 : 256;
#endif // defined(GCN) || defined(CDNA)
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
return 64;
case 128:
return ncols <= 16 ? 128 : 64;
case 256:
return ncols <= 16 ? 64 : 128;
default:
return -1;
}
#else
switch (D) {
case 64:
return 64;
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
#ifdef GGML_USE_HIP
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
#else
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
#endif // GGML_USE_HIP
static __global__ void flash_attn_tile(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = 32;
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size.");
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
#if defined(GGML_USE_HIP)
constexpr int cpy_nb = 16;
#else
constexpr int cpy_nb = 8;
#endif // defined(GGML_USE_HIP) && defined(GCN)
constexpr int cpy_ne = cpy_nb / 4;
__shared__ float KQ[ncols][kq_stride];
#ifdef FAST_FP16_AVAILABLE
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts.
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts.
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
float kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
#ifdef FAST_FP16_AVAILABLE
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
#else
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x] = tmp_h2;
#else
const float2 tmp_f2 = __half22float2(tmp_h2);
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
KV_tmp_f[i_KQ*(kq_nbatch + cpy_ne) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) {
half2 K_k[kq_stride/warp_size][cpy_ne];
half2 Q_k[ncols/nwarps][cpy_ne];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) {
float K_k[kq_stride/warp_size][cpy_ne];
float Q_k[ncols/nwarps][cpy_ne];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_h2[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&K_k[i_KQ_0/warp_size], &KV_tmp_f [i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]);
#else
ggml_cuda_memcpy_1<cpy_nb>(&Q_k[j_KQ_0/nwarps], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]);
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
#pragma unroll
for (int k = 0; k < cpy_ne; ++k) {
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0/nwarps][k]);
}
}
}
}
if (k_KQ_0 + kq_nbatch < D) {
__syncthreads(); // Sync not needed on last iteration.
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
if (use_logit_softcap) {
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
}
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
if (kq_stride % (4*warp_size) == 0 && cpy_ne % 4 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += 4*warp_size) {
const int i = i0 + 4*threadIdx.x;
float4 val = *(const float4 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
val.z = expf(val.z - kqmax[j0/nwarps]);
val.w = expf(val.w - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y + val.z + val.w;
#ifdef FAST_FP16_AVAILABLE
const half2 tmp[2] = {make_half2(val.x, val.y), make_half2(val.z, val.w)};
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
#endif // FAST_FP16_AVAILABLE
}
} else if (kq_stride % (2*warp_size) == 0 && cpy_ne % 2 == 0) {
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += 2*warp_size) {
const int i = i0 + 2*threadIdx.x;
float2 val = *(const float2 *) &KQ[j][i];
val.x = expf(val.x - kqmax[j0/nwarps]);
val.y = expf(val.y - kqmax[j0/nwarps]);
kqsum_add += val.x + val.y;
#ifdef FAST_FP16_AVAILABLE
const half2 tmp = make_half2(val.x, val.y);
ggml_cuda_memcpy_1<sizeof(tmp)>(&KQ[j][i/2], &tmp);
#else
ggml_cuda_memcpy_1<sizeof(val)>(&KQ[j][i], &val);
#endif // FAST_FP16_AVAILABLE
}
} else {
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j][i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
#ifdef FAST_FP16_AVAILABLE
((half *) KQ[j])[i] = val;
#else
KQ[j][i] = val;
#endif // FAST_FP16_AVAILABLE
}
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
#pragma unroll
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
const int k_tile = k1 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
#else
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
#ifdef FAST_FP16_AVAILABLE
half2 V_k[(D/2)/warp_size];
half2 KQ_k[ncols/nwarps];
#else
float2 V_k[(D/2)/warp_size];
float KQ_k[ncols/nwarps];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
#else
V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
KQ_k[j0/nwarps] = __half2half2(((const half *)KQ[j])[k0 + k1]);
#else
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
#ifdef FAST_FP16_AVAILABLE
VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps];
#else
VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
#endif // FAST_FP16_AVAILABLE
}
}
}
__syncthreads();
}
}
// Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
const int i0 = i00 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
#else
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
#endif // FAST_FP16_AVAILABLE
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
constexpr size_t nbytes_shared = 0;
if (Q->ne[1] > 16) {
constexpr int cols_per_block = 32;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
constexpr int cols_per_block = 16;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
}
template <bool use_logit_softcap>
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
} break;
case 128: {
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
} break;
case 256: {
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
}
}

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1,8 +1,7 @@
#include "common.cuh" #include "common.cuh"
#include "fattn-common.cuh" #include "fattn-common.cuh"
#include "fattn-mma-f16.cuh" #include "fattn-mma-f16.cuh"
#include "fattn-tile-f16.cuh" #include "fattn-tile.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-vec-f16.cuh" #include "fattn-vec-f16.cuh"
#include "fattn-vec-f32.cuh" #include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh" #include "fattn-wmma-f16.cuh"
@ -271,8 +270,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
// Best FlashAttention kernel for a specific GPU: // Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel { enum best_fattn_kernel {
BEST_FATTN_KERNEL_NONE = 0, BEST_FATTN_KERNEL_NONE = 0,
BEST_FATTN_KERNEL_TILE_F32 = 200, BEST_FATTN_KERNEL_TILE = 200,
BEST_FATTN_KERNEL_TILE_F16 = 210,
BEST_FATTN_KERNEL_VEC_F32 = 100, BEST_FATTN_KERNEL_VEC_F32 = 100,
BEST_FATTN_KERNEL_VEC_F16 = 110, BEST_FATTN_KERNEL_VEC_F16 = 110,
BEST_FATTN_KERNEL_WMMA_F16 = 300, BEST_FATTN_KERNEL_WMMA_F16 = 300,
@ -411,10 +409,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
} }
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) { return BEST_FATTN_KERNEL_TILE;
return BEST_FATTN_KERNEL_TILE_F16;
}
return BEST_FATTN_KERNEL_TILE_F32;
} }
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -422,11 +417,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) { switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
case BEST_FATTN_KERNEL_NONE: case BEST_FATTN_KERNEL_NONE:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
case BEST_FATTN_KERNEL_TILE_F32: case BEST_FATTN_KERNEL_TILE:
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst); ggml_cuda_flash_attn_ext_tile(ctx, dst);
break;
case BEST_FATTN_KERNEL_TILE_F16:
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
break; break;
case BEST_FATTN_KERNEL_VEC_F32: case BEST_FATTN_KERNEL_VEC_F32:
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst); ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);

View File

@ -2,39 +2,39 @@
#include "dequantize.cuh" #include "dequantize.cuh"
#include "convert.cuh" #include "convert.cuh"
#define MAX_GRIDDIM_Y 65535
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t> template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows( static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) { for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
const int i10 = blockIdx.x; // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i11 = blockIdx.z / ne12; const int i10 = blockIdx.x;
const int i12 = blockIdx.z % ne12; const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12]; const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3; dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03; const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2; const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize // dequantize
float2 v; float2 v;
dequantize_kernel(src0_row, ib, iqs, v); dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x); dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y); dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
} }
} }
@ -42,27 +42,29 @@ template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float( static __global__ void k_get_rows_float(
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst, const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/ const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/ /*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3, /*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03, /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) { const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) { for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
const int i10 = blockIdx.x; // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i11 = blockIdx.z / ne12; const int i10 = blockIdx.x;
const int i12 = blockIdx.z % ne12; const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
if (i00 >= ne00) { if (i00 >= ne00) {
return; return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
} }
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
} }
} }
@ -98,7 +100,7 @@ static void get_rows_cuda_q(
cudaStream_t stream) { cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE); const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements // strides in elements
// const size_t s0 = nb0 / sizeof(dst_t); // const size_t s0 = nb0 / sizeof(dst_t);
@ -116,7 +118,7 @@ static void get_rows_cuda_q(
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>( k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d, src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/ ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/); s10, s11, s12/*, s13*/);
@ -131,7 +133,7 @@ static void get_rows_cuda_float(
cudaStream_t stream) { cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1); const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE; const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12); const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements // strides in elements
// const size_t s0 = nb0 / sizeof(dst_t); // const size_t s0 = nb0 / sizeof(dst_t);
@ -147,7 +149,7 @@ static void get_rows_cuda_float(
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>( k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d, src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/ ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/ /*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3, /* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03, /* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/); s10, s11, s12/*, s13*/);

View File

@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst); ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return; return;
} }
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
return;
}
} }
cudaStream_t stream = ctx.stream(); cudaStream_t stream = ctx.stream();
@ -3135,6 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute, /* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record, /* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait, /* .event_wait = */ ggml_backend_cuda_event_wait,
/* .optimize_graph = */ NULL,
}; };
static ggml_guid_t ggml_backend_cuda_guid() { static ggml_guid_t ggml_backend_cuda_guid() {
@ -3204,6 +3210,7 @@ struct ggml_backend_cuda_device_context {
int device; int device;
std::string name; std::string name;
std::string description; std::string description;
std::string pci_bus_id;
}; };
static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) {
@ -3228,9 +3235,12 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend
} }
static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) {
ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context;
props->name = ggml_backend_cuda_device_get_name(dev); props->name = ggml_backend_cuda_device_get_name(dev);
props->description = ggml_backend_cuda_device_get_description(dev); props->description = ggml_backend_cuda_device_get_description(dev);
props->type = ggml_backend_cuda_device_get_type(dev); props->type = ggml_backend_cuda_device_get_type(dev);
props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str();
ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total);
bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr;
@ -3461,6 +3471,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) { if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true; return true;
} }
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
return true;
}
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) { if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true; return true;
} }
@ -3574,9 +3590,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM: case GGML_OP_GROUP_NORM:
case GGML_OP_PAD:
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
@ -3792,6 +3808,10 @@ ggml_backend_reg_t ggml_backend_cuda_reg() {
CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); CUDA_CHECK(cudaGetDeviceProperties(&prop, i));
dev_ctx->description = prop.name; dev_ctx->description = prop.name;
char pci_bus_id[16] = {};
snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID);
dev_ctx->pci_bus_id = pci_bus_id;
ggml_backend_dev_t dev = new ggml_backend_device { ggml_backend_dev_t dev = new ggml_backend_device {
/* .iface = */ ggml_backend_cuda_device_interface, /* .iface = */ ggml_backend_cuda_device_interface,
/* .reg = */ &reg, /* .reg = */ &reg,

View File

@ -1,3 +1,4 @@
#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code. // This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout. // The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under: // The documentation for the PTX instructions can be found under:

View File

@ -1,343 +1,12 @@
#include "ggml.h" #include "ggml.h"
#include "common.cuh"
#include "mma.cuh"
#include "mmf.cuh" #include "mmf.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int channel_dst = blockIdx.y;
const int channel_x = channel_dst / channel_ratio;
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
tile_C C[ntA][ntB];
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) data_mmv;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
}
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
GGML_ASSERT(!ids && "mul_mat_id not implemented");
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) { void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32); GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32); GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32); GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS; GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type); const size_t ts_src0 = ggml_type_size(src0->type);
@ -365,55 +34,72 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t s13 = src1->nb[3] / ts_src1; const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst; const int64_t s3 = dst->nb[3] / ts_dst;
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT: // For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1; const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12; const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1); const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
int64_t stride_channel_y = ids ? s11 : s12;
int64_t nchannels_y = ids ? ne11 : ne12;
//mul_mat_id: handle broadcast
if (ids && nchannels_y == 1) {
stride_channel_y = 0;
nchannels_y = ids->ne[0];
}
switch (src0->type) { switch (src0->type) {
case GGML_TYPE_F32: { case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data; const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1; constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break; } break;
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data; const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2; constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break; } break;
case GGML_TYPE_BF16: { case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data; const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2; constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block( mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1, src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst, ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream()); ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break; } break;
default: default:
GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type)); GGML_ABORT("unsupported type: %s", ggml_type_name(src0->type));
} }
} }
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) { bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
if (ggml_is_quantized(type)) {
return false;
}
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) { if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false; return false;
} }
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) { if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false; return false;
} }
if (ne11 > 16) { if (src1_ncols > 16) {
return false; return false;
} }
switch (type) { switch (type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
return ampere_mma_available(cc); return ampere_mma_available(cc);

View File

@ -1,5 +1,473 @@
#pragma once
#include "mma.cuh"
#include "common.cuh" #include "common.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst); void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, int64_t ne11); bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * scr0_ne, const int src1_ncols);
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = has_ids ? blockIdx.y : 0;
const int channel_dst = has_ids ? 0 : blockIdx.y;
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
char * shmem_base = data_mmv;
int * slot_map = (int *) shmem_base;
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
if constexpr (has_ids) {
__shared__ int has_any;
if (threadIdx.y == 0) {
int local_has_any = 0;
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
int slot = -1;
for (int k = 0; k < nchannels_dst; ++k) {
const int idv = ids[j*stride_row_id + k*stride_col_id];
if (idv == expert_idx) {
slot = k;
break;
}
}
if (j < cols_per_block) {
local_has_any |= (slot >= 0);
slot_map[j] = slot;
}
}
has_any = warp_reduce_any(local_has_any);
}
__syncthreads();
if (has_any == 0) {
return;
}
}
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
} else {
float val = 0.0f;
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
val = y[slot*stride_channel_y + j*stride_col_y + col];
}
}
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
} else {
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
tmp = y2_slot[j*stride_col_y + col];
}
}
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nchannels_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
if (ids) {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif

View File

@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk; constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi; constexpr int qi = ggml_cuda_type_traits<type>::qi;
@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const int channel_dst = blockIdx.y; const uint32_t channel_dst = blockIdx.y;
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
const int sample_dst = blockIdx.z; const uint32_t sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const int sample_y = sample_dst; const uint32_t sample_y = sample_dst;
// partial sum for each thread // partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0); GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE); GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const int channel_ratio = nchannels_dst / nchannels_x; const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const int sample_ratio = nsamples_dst / nsamples_x; const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device(); const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size; const int warp_size = ggml_cuda_info().devices[device].warp_size;
@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
GGML_ASSERT(!ids || ncols_dst == 1); GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) { switch (ncols_dst) {
case 1: case 1: {
{
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 2: {
case 2:
{
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 3: {
case 3:
{
constexpr int c_ncols_dst = 3; constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 4: {
case 4:
{
constexpr int c_ncols_dst = 4; constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 5: {
case 5:
{
constexpr int c_ncols_dst = 5; constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 6: {
case 6:
{
constexpr int c_ncols_dst = 6; constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 7: {
case 7:
{
constexpr int c_ncols_dst = 7; constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 8: {
case 8:
{
constexpr int c_ncols_dst = 8; constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
}
default: default:
GGML_ABORT("fatal error"); GGML_ABORT("fatal error");
break; break;

View File

@ -1,26 +1,27 @@
#include "quantize.cuh" #include "quantize.cuh"
#include <cstdint> #include <cstdint>
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
static __global__ void quantize_q8_1( static __global__ void quantize_q8_1(
const float * __restrict__ x, void * __restrict__ vy, const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03, const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) { const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= ne0) { if (i0 >= ne0) {
return; return;
} }
const int64_t i3 = fastdiv(blockIdx.z, ne2);
const int64_t i2 = blockIdx.z - i3*ne2.z;
const int64_t i1 = blockIdx.y; const int64_t i1 = blockIdx.y;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t & i00 = i0; const int64_t & i00 = i0;
const int64_t & i01 = i1; const int64_t & i01 = i1;
const int64_t & i02 = i2; const int64_t & i02 = i2;
const int64_t & i03 = i3; const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
block_q8_1 * y = (block_q8_1 *) vy; block_q8_1 * y = (block_q8_1 *) vy;
@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
float amax = fabsf(xi); float amax = fabsf(xi);
float sum = xi; float sum = xi;
amax = warp_reduce_max(amax); amax = warp_reduce_max<QK8_1>(amax);
sum = warp_reduce_sum(sum); sum = warp_reduce_sum<QK8_1>(sum);
const float d = amax / 127; const float d = amax / 127.0f;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q; y[ib].qs[iqs] = q;
@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
return; return;
} }
reinterpret_cast<half&>(y[ib].ds.x) = d; y[ib].ds = make_half2(d, sum);
reinterpret_cast<half&>(y[ib].ds.y) = sum;
} }
template <mmq_q8_1_ds_layout ds_layout> template <mmq_q8_1_ds_layout ds_layout>
@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
GGML_ASSERT(!ids); GGML_ASSERT(!ids);
GGML_ASSERT(ne0 % QK8_1 == 0); GGML_ASSERT(ne0 % QK8_1 == 0);
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
GGML_UNUSED(type_src0); GGML_UNUSED(type_src0);
} }

View File

@ -24,7 +24,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0", "GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K", "GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S", "GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS" "GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
] ]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually. SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
DECL_MMQ_CASE({type}); DECL_MMQ_CASE({type});
""" """
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE({type});
"""
def get_short_name(long_quant_name): def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower() return long_quant_name.replace("GGML_TYPE_", "").lower()
@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
for type in TYPES_MMQ: for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f: with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type)) f.write(SOURCE_MMQ.format(type=type))
for type in range(1, 17):
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
f.write(SOURCE_MMF.format(type=type))

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(1);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(10);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(11);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(12);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(13);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(14);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(15);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(16);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(2);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(3);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(4);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(5);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(6);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(7);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(8);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(9);

View File

@ -162,6 +162,14 @@
#define GCN #define GCN
#endif #endif
#if defined(__gfx900__) || defined(__gfx906__)
#define GCN5
#endif
#if defined(__gfx803__)
#define GCN4
#endif
#if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__) #if defined(__gfx908__) || defined(__gfx90a__) || defined(__gfx942__)
#define CDNA // For the entire family #define CDNA // For the entire family
#endif #endif

View File

@ -20,8 +20,8 @@
#define N_R0_Q5_1 4 #define N_R0_Q5_1 4
#define N_SG_Q5_1 2 #define N_SG_Q5_1 2
#define N_R0_Q8_0 4 #define N_R0_Q8_0 2
#define N_SG_Q8_0 2 #define N_SG_Q8_0 4
#define N_R0_MXFP4 2 #define N_R0_MXFP4 2
#define N_SG_MXFP4 2 #define N_SG_MXFP4 2
@ -68,6 +68,11 @@
#define N_R0_IQ4_XS 2 #define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2 #define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
// kernel argument structs // kernel argument structs
// //
// - element counters (e.g. ne00) typically use int32_t to reduce register usage // - element counters (e.g. ne00) typically use int32_t to reduce register usage
@ -236,9 +241,11 @@ typedef struct {
int32_t ne11; int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3; int32_t ne_12_3;
int32_t ns10;
uint64_t nb11; uint64_t nb11;
uint64_t nb12; uint64_t nb12;
uint64_t nb13; uint64_t nb13;
int32_t ns20;
uint64_t nb21; uint64_t nb21;
uint64_t nb22; uint64_t nb22;
uint64_t nb23; uint64_t nb23;
@ -258,10 +265,43 @@ typedef struct {
float logit_softcap; float logit_softcap;
} ggml_metal_kargs_flash_attn_ext; } ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1;
int32_t ne2;
int32_t ne3;
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext_vec;
typedef struct { typedef struct {
int32_t nrows; int32_t nrows;
int32_t ne20; } ggml_metal_kargs_flash_attn_ext_vec_reduce;
} ggml_metal_kargs_flash_attn_ext_reduce;
typedef struct { typedef struct {
int32_t ne00; int32_t ne00;

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@ -2838,6 +2838,7 @@ static ggml_backend_i ggml_backend_opencl_i = {
/* .graph_compute = */ ggml_backend_opencl_graph_compute, /* .graph_compute = */ ggml_backend_opencl_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
ggml_backend_t ggml_backend_opencl_init(void) { ggml_backend_t ggml_backend_opencl_init(void) {

View File

@ -795,6 +795,7 @@ static ggml_backend_i ggml_backend_rpc_interface = {
/* .graph_compute = */ ggml_backend_rpc_graph_compute, /* .graph_compute = */ ggml_backend_rpc_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) { ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint) {

View File

@ -4063,6 +4063,7 @@ static ggml_backend_i ggml_backend_sycl_interface = {
/* .graph_compute = */ ggml_backend_sycl_graph_compute, /* .graph_compute = */ ggml_backend_sycl_graph_compute,
/* .event_record = */ ggml_backend_sycl_event_record, /* .event_record = */ ggml_backend_sycl_event_record,
/* .event_wait = */ ggml_backend_sycl_event_wait, /* .event_wait = */ ggml_backend_sycl_event_wait,
/* .optimize_graph = */ NULL,
}; };
static ggml_guid_t ggml_backend_sycl_guid() { static ggml_guid_t ggml_backend_sycl_guid() {

View File

@ -506,8 +506,8 @@ struct vk_device_struct {
vk_pipeline pipeline_pad_f32; vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32; vk_pipeline pipeline_roll_f32;
vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32;
vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16; vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32;
vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16; vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32;
vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT];
vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT]; vk_pipeline pipeline_set_rows[GGML_TYPE_COUNT];
@ -554,6 +554,7 @@ struct vk_device_struct {
vk_pipeline pipeline_argmax_f32; vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32; vk_pipeline pipeline_count_equal_i32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32; vk_pipeline pipeline_timestep_embedding_f32;
vk_pipeline pipeline_conv_transpose_1d_f32; vk_pipeline pipeline_conv_transpose_1d_f32;
vk_pipeline pipeline_pool2d_f32; vk_pipeline pipeline_pool2d_f32;
@ -582,6 +583,7 @@ struct vk_device_struct {
bool disable_fusion; bool disable_fusion;
bool disable_host_visible_vidmem; bool disable_host_visible_vidmem;
bool allow_sysmem_fallback; bool allow_sysmem_fallback;
bool disable_optimize_graph;
#ifdef GGML_VULKAN_MEMORY_DEBUG #ifdef GGML_VULKAN_MEMORY_DEBUG
std::unique_ptr<vk_memory_logger> memory_logger; std::unique_ptr<vk_memory_logger> memory_logger;
@ -803,6 +805,57 @@ static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_ten
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
return p; // offsets are initialized later in ggml_vk_op
}
struct vk_op_pad_push_constants {
uint32_t ne;
uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03;
uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13;
uint32_t misalign_offsets;
uint32_t lp0; uint32_t rp0;
uint32_t lp1; uint32_t rp1;
uint32_t lp2; uint32_t rp2;
uint32_t lp3; uint32_t rp3;
};
static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) {
int64_t ne = ggml_nelements(dst);
GGML_ASSERT(ne <= (int64_t)std::numeric_limits<uint32_t>::max());
vk_op_pad_push_constants p{};
p.ne = (uint32_t)ne;
size_t src0_tsize = ggml_type_size(src0->type);
p.ne00 = (uint32_t)src0->ne[0];
p.ne01 = (uint32_t)src0->ne[1];
p.ne02 = (uint32_t)src0->ne[2];
p.ne03 = (uint32_t)src0->ne[3];
p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize);
p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize);
p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize);
p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize);
size_t dst_tsize = ggml_type_size(dst->type);
p.ne10 = (uint32_t)dst->ne[0];
p.ne11 = (uint32_t)dst->ne[1];
p.ne12 = (uint32_t)dst->ne[2];
p.ne13 = (uint32_t)dst->ne[3];
p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize);
p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize);
p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize);
p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize);
p.lp0 = dst->op_params[0];
p.rp0 = dst->op_params[1];
p.lp1 = dst->op_params[2];
p.rp1 = dst->op_params[3];
p.lp2 = dst->op_params[4];
p.rp2 = dst->op_params[5];
p.lp3 = dst->op_params[6];
p.rp3 = dst->op_params[7];
return p; // fastdiv values and offsets are initialized later in ggml_vk_op return p; // fastdiv values and offsets are initialized later in ggml_vk_op
} }
@ -931,6 +984,37 @@ struct vk_op_im2col_push_constants {
int32_t d0; int32_t d1; int32_t d0; int32_t d1;
}; };
struct vk_op_im2col_3d_push_constants {
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
};
struct vk_op_timestep_embedding_push_constants { struct vk_op_timestep_embedding_push_constants {
uint32_t nb1; uint32_t nb1;
uint32_t dim; uint32_t dim;
@ -1853,7 +1937,9 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties();
for (auto &req_flags : req_flags_list) { for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) {
const auto & req_flags = *it;
uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags);
if (memory_type_index == UINT32_MAX) { if (memory_type_index == UINT32_MAX) {
@ -1866,6 +1952,11 @@ static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std
break; break;
} catch (const vk::SystemError& e) { } catch (const vk::SystemError& e) {
// loop and retry // loop and retry
// during last attempt throw the exception
if (it + 1 == req_flags_list.end()) {
device->device.destroyBuffer(buf->buffer);
throw e;
}
} }
} }
@ -3143,12 +3234,16 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
if (device->float_controls_rte_fp16) { if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1);
@ -3250,7 +3345,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
@ -3299,7 +3394,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1);
ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1);
@ -3329,10 +3424,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32_len, im2col_f32_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32_len, im2col_3d_f32_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
if (device->float_controls_rte_fp16) { if (device->float_controls_rte_fp16) {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte_len, im2col_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte_len, im2col_3d_f32_f16_rte_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
} else { } else {
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_len, im2col_f32_f16_data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_len, im2col_3d_f32_f16_data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true);
} }
ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1);
@ -3502,6 +3600,9 @@ static vk_device ggml_vk_get_device(size_t idx) {
const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK");
device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr;
const char* GGML_VK_DISABLE_OPTIMIZE_GRAPH = getenv("GGML_VK_DISABLE_OPTIMIZE_GRAPH");
device->disable_optimize_graph = GGML_VK_DISABLE_OPTIMIZE_GRAPH != nullptr;
bool fp16_storage = false; bool fp16_storage = false;
bool fp16_compute = false; bool fp16_compute = false;
bool maintenance4_support = false; bool maintenance4_support = false;
@ -3642,6 +3743,12 @@ static vk_device ggml_vk_get_device(size_t idx) {
device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic);
#ifdef __APPLE__
// Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846)
if (device->vendor_id == VK_VENDOR_ID_AMD) {
device->subgroup_arithmetic = false;
}
#endif
device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
(vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle);
device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) &&
@ -5607,6 +5714,20 @@ static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cpy_f32_bf16; return ctx->device->pipeline_cpy_f32_bf16;
} }
} }
if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) {
if (contig) {
return ctx->device->pipeline_contig_cpy_f32_i32;
} else {
return ctx->device->pipeline_cpy_f32_i32;
}
}
if (src->type == GGML_TYPE_I32 && to == GGML_TYPE_F32) {
if (contig) {
return ctx->device->pipeline_contig_cpy_i32_f32;
} else {
return ctx->device->pipeline_cpy_i32_f32;
}
}
if (src->type == GGML_TYPE_F32) { if (src->type == GGML_TYPE_F32) {
switch (to) { switch (to) {
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -7666,6 +7787,14 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_im2col_f32_f16; return ctx->device->pipeline_im2col_f32_f16;
} }
return nullptr; return nullptr;
case GGML_OP_IM2COL_3D:
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_im2col_3d_f32;
}
if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) {
return ctx->device->pipeline_im2col_3d_f32_f16;
}
return nullptr;
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_timestep_embedding_f32; return ctx->device->pipeline_timestep_embedding_f32;
@ -7781,6 +7910,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_RMS_NORM: case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_SET_ROWS: case GGML_OP_SET_ROWS:
case GGML_OP_SUM: case GGML_OP_SUM:
case GGML_OP_SUM_ROWS: case GGML_OP_SUM_ROWS:
@ -7829,6 +7959,26 @@ template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk
GGML_UNUSED(src2); GGML_UNUSED(src2);
} }
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src1);
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
p.misalign_offsets = (a_offset << 16) | d_offset;
GGML_UNUSED(src0);
GGML_UNUSED(src2);
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type);
@ -8069,6 +8219,26 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { OW * KW * KH, OH, batch * IC }; elements = { OW * KW * KH, OH, batch * IC };
} break; } break;
case GGML_OP_IM2COL_3D:
{
const uint32_t IC = ((const uint32_t *)(dst->op_params))[9];
const uint32_t N = ne13 / IC;
const uint32_t KD = ne02;
const uint32_t KH = ne01;
const uint32_t KW = ne00;
const uint32_t OD = ned3 / N;
const uint32_t OH = ned2;
const uint32_t OW = ned1;
const uint32_t IC_KD_KH_KW = IC*KD*KH*KW;
const uint32_t N_OD_OH = N*OD*OH;
elements = { IC_KD_KH_KW, OW, N_OD_OH };
elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]);
} break;
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
{ {
const uint32_t dim = dst->op_params[0]; const uint32_t dim = dst->op_params[0];
@ -8225,7 +8395,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
} }
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_IM2COL) { } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) {
// im2col uses only src1 and dst buffers // im2col uses only src1 and dst buffers
ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements);
} else if (op == GGML_OP_COUNT_EQUAL) { } else if (op == GGML_OP_COUNT_EQUAL) {
@ -8771,7 +8941,7 @@ static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, con
} }
static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun);
} }
@ -8982,7 +9152,7 @@ static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx,
static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
float * op_params = (float *)dst->op_params; float * op_params = (float *)dst->op_params;
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], op_params[1] }, dryrun); ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun);
} }
static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) {
@ -9086,6 +9256,66 @@ static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, co
}, dryrun); }, dryrun);
} }
static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
vk_op_im2col_3d_push_constants pc {};
pc.nb10 = nb10 / ggml_type_size(src1->type);
pc.nb11 = nb11 / ggml_type_size(src1->type);
pc.nb12 = nb12 / ggml_type_size(src1->type);
pc.nb13 = nb13 / ggml_type_size(src1->type);
pc.s0 = s0;
pc.s1 = s1;
pc.s2 = s2;
pc.p0 = p0;
pc.p1 = p1;
pc.p2 = p2;
pc.d0 = d0;
pc.d1 = d1;
pc.d2 = d2;
pc.IW = IW;
pc.IH = IH;
pc.ID = ID;
pc.IC = IC;
pc.KW = KW;
pc.OH = OH;
pc.KD_KH_KW = KD*KH*KW;
pc.KH_KW = KH*KW;
pc.IC_KD_KH_KW = IC*KD*KH*KW;
pc.N_OD_OH = N*OD*OH;
pc.OD_OH = OD*OH;
pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
ggml_vk_op_f32<vk_op_im2col_3d_push_constants>(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun);
}
static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) {
const uint32_t dim = dst->op_params[0]; const uint32_t dim = dst->op_params[0];
const uint32_t max_period = dst->op_params[1]; const uint32_t max_period = dst->op_params[1];
@ -10291,6 +10521,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
@ -10361,6 +10592,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
@ -10656,6 +10888,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun);
break;
case GGML_OP_IM2COL_3D:
ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun);
break; break;
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun);
@ -10807,6 +11043,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_TRANSPOSE_1D: case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
@ -11633,6 +11870,131 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
UNUSED(backend); UNUSED(backend);
} }
// Sort the graph for improved parallelism.
static void ggml_vk_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * graph)
{
VK_LOG_DEBUG("ggml_vk_optimize_graph(" << graph->n_nodes << " nodes)");
ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context;
if (ctx->device->disable_optimize_graph) {
return;
}
auto const &is_empty = [](ggml_tensor * node) -> bool {
return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE;
};
auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool {
for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) {
if (dst->src[s] == src) {
return true;
}
}
// implicit dependency if they view the same tensor
const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst;
const ggml_tensor *src2 = src->view_src ? src->view_src : src;
if (dst2 == src2) {
return true;
}
return false;
};
// This function tries to reorder the graph to allow nodes to run in parallel.
// This helps with small batches, but for large batches its a slowdown, probably
// due to cache contention. So only reorder if the majority of nodes have few rows.
int num_small_nodes = 0;
int num_counted_nodes = 0;
for (int i = 0; i < graph->n_nodes; ++i) {
if (!is_empty(graph->nodes[i]) &&
graph->nodes[i]->op != GGML_OP_SET_ROWS) {
if (ggml_nrows(graph->nodes[i]) <= 8) {
num_small_nodes++;
}
num_counted_nodes++;
}
}
if (num_small_nodes < num_counted_nodes / 2) {
return;
}
std::vector<ggml_tensor *> new_order;
std::vector<bool> used(graph->n_nodes, false);
int first_unused = 0;
while (first_unused < graph->n_nodes) {
std::vector<int> current_set;
// First, grab the next unused node.
current_set.push_back(first_unused);
// Loop through the next N nodes. Grab any that don't depend on other nodes that
// haven't already been run. Nodes that have already been run have used[i] set
// to true. Allow nodes that depend on the previous node if it's a fusion pattern
// that we support (e.g. RMS_NORM + MUL).
// This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes.
// The goal is to not interleave real and view nodes in a way that breaks fusion.
const int NUM_TO_CHECK = 20;
for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
if (used[j]) {
continue;
}
if (is_empty(graph->nodes[j])) {
continue;
}
bool ok = true;
for (int c = first_unused; c < j; ++c) {
if (!used[c] &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) {
ok = false;
break;
}
}
if (ok) {
current_set.push_back(j);
}
}
// Second pass grabs view nodes.
// Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add).
if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) {
for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) {
if (used[j]) {
continue;
}
if (!is_empty(graph->nodes[j])) {
continue;
}
bool ok = true;
for (int c = first_unused; c < j; ++c) {
bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end();
// skip views whose srcs haven't been processed.
if (!used[c] &&
is_src_of(graph->nodes[j], graph->nodes[c]) &&
!c_in_current_set) {
ok = false;
break;
}
}
if (ok) {
current_set.push_back(j);
}
}
}
// Push the current set into new_order
for (auto c : current_set) {
new_order.push_back(graph->nodes[c]);
used[c] = true;
}
while (first_unused < graph->n_nodes && used[first_unused]) {
first_unused++;
}
}
// Replace the graph with the new order.
for (int i = 0; i < graph->n_nodes; ++i) {
graph->nodes[i] = new_order[i];
}
}
// TODO: enable async and synchronize // TODO: enable async and synchronize
static ggml_backend_i ggml_backend_vk_interface = { static ggml_backend_i ggml_backend_vk_interface = {
/* .get_name = */ ggml_backend_vk_name, /* .get_name = */ ggml_backend_vk_name,
@ -11648,6 +12010,7 @@ static ggml_backend_i ggml_backend_vk_interface = {
/* .graph_compute = */ ggml_backend_vk_graph_compute, /* .graph_compute = */ ggml_backend_vk_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ ggml_vk_optimize_graph,
}; };
static ggml_guid_t ggml_backend_vk_guid() { static ggml_guid_t ggml_backend_vk_guid() {
@ -11750,6 +12113,7 @@ static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(gg
static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) {
UNUSED(dev); UNUSED(dev);
// TODO: return GGML_BACKEND_DEVICE_TYPE_IGPU for integrated GPUs
return GGML_BACKEND_DEVICE_TYPE_GPU; return GGML_BACKEND_DEVICE_TYPE_GPU;
} }
@ -11757,6 +12121,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml
props->name = ggml_backend_vk_device_get_name(dev); props->name = ggml_backend_vk_device_get_name(dev);
props->description = ggml_backend_vk_device_get_description(dev); props->description = ggml_backend_vk_device_get_description(dev);
props->type = ggml_backend_vk_device_get_type(dev); props->type = ggml_backend_vk_device_get_type(dev);
// TODO: set props->device_id to PCI bus id
ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total);
props->caps = { props->caps = {
/* .async = */ false, /* .async = */ false,
@ -12022,6 +12387,13 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true; return true;
} }
if (
src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32 ||
src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32
) {
return true;
}
// We can handle copying from a type to the same type if it's // We can handle copying from a type to the same type if it's
// contiguous (memcpy). We use f16 or f32 shaders to do the copy, // contiguous (memcpy). We use f16 or f32 shaders to do the copy,
// so the type/block size must be a multiple of 4. // so the type/block size must be a multiple of 4.
@ -12076,10 +12448,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_ACC: case GGML_OP_ACC:
case GGML_OP_CONCAT: case GGML_OP_CONCAT:
case GGML_OP_SCALE: case GGML_OP_SCALE:
return true;
case GGML_OP_PAD: case GGML_OP_PAD:
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_ROLL: case GGML_OP_ROLL:
case GGML_OP_DIAG_MASK_INF: case GGML_OP_DIAG_MASK_INF:
case GGML_OP_SOFT_MAX: case GGML_OP_SOFT_MAX:
@ -12092,6 +12461,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
case GGML_OP_IM2COL: case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_CONV_2D_DW: case GGML_OP_CONV_2D_DW:
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
@ -12520,7 +12890,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const float * params = (const float *)tensor->op_params; const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
} else if (tensor->op == GGML_OP_PAD) { } else if (tensor->op == GGML_OP_PAD) {
tensor_clone = ggml_pad(ggml_ctx, src_clone[0], tensor->ne[0] - src_clone[0]->ne[0], tensor->ne[1] - src_clone[0]->ne[1], tensor->ne[2] - src_clone[0]->ne[2], tensor->ne[3] - src_clone[0]->ne[3]); tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3],
tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]);
} else if (tensor->op == GGML_OP_REPEAT) { } else if (tensor->op == GGML_OP_REPEAT) {
tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor);
} else if (tensor->op == GGML_OP_REPEAT_BACK) { } else if (tensor->op == GGML_OP_REPEAT_BACK) {
@ -12666,6 +13037,19 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
const bool is_2D = tensor->op_params[6] == 1; const bool is_2D = tensor->op_params[6] == 1;
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type);
} else if (tensor->op == GGML_OP_IM2COL_3D) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];
const int32_t s1 = tensor->op_params[2];
const int32_t p0 = tensor->op_params[3];
const int32_t p1 = tensor->op_params[4];
const int32_t p1 = tensor->op_params[5];
const int32_t d0 = tensor->op_params[6];
const int32_t d1 = tensor->op_params[7];
const int32_t d1 = tensor->op_params[8];
const int32_t IC = tensor->op_params[9];
tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type);
} else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) {
const int32_t dim = tensor->op_params[0]; const int32_t dim = tensor->op_params[0];
const int32_t max_period = tensor->op_params[1]; const int32_t max_period = tensor->op_params[1];

View File

@ -0,0 +1,112 @@
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "rte.comp"
layout (push_constant) uniform parameter
{
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
} p;
#include "types.comp"
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint32_t i = gl_GlobalInvocationID.x;
uint32_t nb10 = p.nb10;
uint32_t nb11 = p.nb11;
uint32_t nb12 = p.nb12;
uint32_t nb13 = p.nb13;
uint32_t s0 = p.s0;
uint32_t s1 = p.s1;
uint32_t s2 = p.s2;
uint32_t p0 = p.p0;
uint32_t p1 = p.p1;
uint32_t p2 = p.p2;
uint32_t d0 = p.d0;
uint32_t d1 = p.d1;
uint32_t d2 = p.d2;
uint32_t IW = p.IW;
uint32_t IH = p.IH;
uint32_t ID = p.ID;
uint32_t IC = p.IC;
uint32_t KW = p.KW;
uint32_t OH = p.OH;
uint32_t KD_KH_KW = p.KD_KH_KW;
uint32_t KH_KW = p.KH_KW;
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
uint32_t N_OD_OH = p.N_OD_OH;
uint32_t OD_OH = p.OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;
if (i >= IC_KD_KH_KW) {
return;
}
const uint32_t iic = i / KD_KH_KW;
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const uint32_t ikw = i % KW;
const uint32_t iow = gl_GlobalInvocationID.y;
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
const uint32_t in_ = iz / OD_OH;
const uint32_t iod = (iz - in_*OD_OH) / OH;
const uint32_t ioh = iz % OH;
const uint32_t iiw = iow * s0 + ikw * d0 - p0;
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;
const uint32_t offset_dst = in_*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
}
}

View File

@ -315,21 +315,23 @@ void main() {
#if LOAD_VEC_A == 8 #if LOAD_VEC_A == 8
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx][0].x); A_TYPE32 aa = A_TYPE32(data_a[idx]);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx][0].y); buf_a[buf_idx ] = FLOAT_TYPE(aa[0].x);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx][0].z); buf_a[buf_idx + 1] = FLOAT_TYPE(aa[0].y);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx][0].w); buf_a[buf_idx + 2] = FLOAT_TYPE(aa[0].z);
buf_a[buf_idx + 4] = FLOAT_TYPE(data_a[idx][1].x); buf_a[buf_idx + 3] = FLOAT_TYPE(aa[0].w);
buf_a[buf_idx + 5] = FLOAT_TYPE(data_a[idx][1].y); buf_a[buf_idx + 4] = FLOAT_TYPE(aa[1].x);
buf_a[buf_idx + 6] = FLOAT_TYPE(data_a[idx][1].z); buf_a[buf_idx + 5] = FLOAT_TYPE(aa[1].y);
buf_a[buf_idx + 7] = FLOAT_TYPE(data_a[idx][1].w); buf_a[buf_idx + 6] = FLOAT_TYPE(aa[1].z);
buf_a[buf_idx + 7] = FLOAT_TYPE(aa[1].w);
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a; const uint idx = pos_a + (loadc_a + l) * p.stride_a / LOAD_VEC_A + loadr_a;
const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A; const uint buf_idx = (loadc_a + l) * SHMEM_STRIDE + loadr_a * LOAD_VEC_A;
buf_a[buf_idx ] = FLOAT_TYPE(data_a[idx].x); A_TYPE32 aa = A_TYPE32(data_a[idx]);
buf_a[buf_idx + 1] = FLOAT_TYPE(data_a[idx].y); buf_a[buf_idx ] = FLOAT_TYPE(aa.x);
buf_a[buf_idx + 2] = FLOAT_TYPE(data_a[idx].z); buf_a[buf_idx + 1] = FLOAT_TYPE(aa.y);
buf_a[buf_idx + 3] = FLOAT_TYPE(data_a[idx].w); buf_a[buf_idx + 2] = FLOAT_TYPE(aa.z);
buf_a[buf_idx + 3] = FLOAT_TYPE(aa.w);
#else #else
if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) { if (ir * BM + loadc_a + l < p.M && block + loadr_a < end_k) {
buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]); buf_a[(loadc_a + l) * SHMEM_STRIDE + loadr_a] = FLOAT_TYPE(data_a[pos_a + (loadc_a + l) * p.stride_a + loadr_a]);
@ -808,14 +810,19 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = FLOAT_TYPE(data_b[idx][0].x); #if defined(DATA_B_BF16)
buf_b[buf_idx + 1] = FLOAT_TYPE(data_b[idx][0].y); B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
buf_b[buf_idx + 2] = FLOAT_TYPE(data_b[idx][0].z); #else
buf_b[buf_idx + 3] = FLOAT_TYPE(data_b[idx][0].w); B_TYPE32 bb = B_TYPE32(data_b[idx]);
buf_b[buf_idx + 4] = FLOAT_TYPE(data_b[idx][1].x); #endif
buf_b[buf_idx + 5] = FLOAT_TYPE(data_b[idx][1].y); buf_b[buf_idx + 0] = FLOAT_TYPE(bb[0].x);
buf_b[buf_idx + 6] = FLOAT_TYPE(data_b[idx][1].z); buf_b[buf_idx + 1] = FLOAT_TYPE(bb[0].y);
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); buf_b[buf_idx + 2] = FLOAT_TYPE(bb[0].z);
buf_b[buf_idx + 3] = FLOAT_TYPE(bb[0].w);
buf_b[buf_idx + 4] = FLOAT_TYPE(bb[1].x);
buf_b[buf_idx + 5] = FLOAT_TYPE(bb[1].y);
buf_b[buf_idx + 6] = FLOAT_TYPE(bb[1].z);
buf_b[buf_idx + 7] = FLOAT_TYPE(bb[1].w);
#elif LOAD_VEC_B == 4 #elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[loadc_b + l]; const u16vec2 row_idx = row_ids[loadc_b + l];
@ -824,10 +831,15 @@ void main() {
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
#endif #endif
const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B; const uint buf_idx = (loadc_b + l) * SHMEM_STRIDE + loadr_b * LOAD_VEC_B;
buf_b[buf_idx + 0] = TO_FLOAT_TYPE(data_b[idx].x); #if defined(DATA_B_BF16)
buf_b[buf_idx + 1] = TO_FLOAT_TYPE(data_b[idx].y); B_TYPE32 bb = TO_FLOAT_TYPE(data_b[idx]);
buf_b[buf_idx + 2] = TO_FLOAT_TYPE(data_b[idx].z); #else
buf_b[buf_idx + 3] = TO_FLOAT_TYPE(data_b[idx].w); B_TYPE32 bb = B_TYPE32(data_b[idx]);
#endif
buf_b[buf_idx + 0] = FLOAT_TYPE(bb.x);
buf_b[buf_idx + 1] = FLOAT_TYPE(bb.y);
buf_b[buf_idx + 2] = FLOAT_TYPE(bb.z);
buf_b[buf_idx + 3] = FLOAT_TYPE(bb.w);
#elif !MUL_MAT_ID #elif !MUL_MAT_ID
if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) { if (ic * BN + loadc_b + l < p.N && block + loadr_b < end_k) {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + (loadc_b + l) * p.stride_b + loadr_b]);

View File

@ -1,7 +1,25 @@
#version 450 #version 450
#include "types.comp" #include "types.comp"
#include "generic_unary_head.comp"
layout (push_constant) uniform parameter
{
uint ne;
uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03;
uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13;
uint misalign_offsets;
uint lp0; uint rp0;
uint lp1; uint rp1;
uint lp2; uint rp2;
uint lp3; uint rp3;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
@ -19,10 +37,13 @@ void main() {
const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; const uint i1 = (idx - i3_offset - i2_offset) / p.ne10;
const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10;
const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00;
const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10;
const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 &&
i1 >= p.lp1 && i1 < p.ne11 - p.rp1 &&
i2 >= p.lp2 && i2 < p.ne12 - p.rp2 &&
i3 >= p.lp3 && i3 < p.ne13 - p.rp3;
data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f);
} }

View File

@ -20,6 +20,10 @@ void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
if (row >= p.KY) {
return;
}
FLOAT_TYPE scale = p.param1; FLOAT_TYPE scale = p.param1;
// partial sums for thread in warp // partial sums for thread in warp

View File

@ -13,10 +13,13 @@
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float #define A_TYPE float
#define A_TYPE32 float
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
#define A_TYPE vec4 #define A_TYPE vec4
#define A_TYPE32 vec4
#elif LOAD_VEC_A == 8 #elif LOAD_VEC_A == 8
#define A_TYPE mat2x4 #define A_TYPE mat2x4
#define A_TYPE32 mat2x4
#endif #endif
#endif #endif
@ -26,10 +29,13 @@
#if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1 #if !defined(LOAD_VEC_A) || LOAD_VEC_A == 1
#define A_TYPE float16_t #define A_TYPE float16_t
#define A_TYPE32 float
#elif LOAD_VEC_A == 4 #elif LOAD_VEC_A == 4
#define A_TYPE f16vec4 #define A_TYPE f16vec4
#define A_TYPE32 vec4
#elif LOAD_VEC_A == 8 #elif LOAD_VEC_A == 8
#define A_TYPE f16mat2x4 #define A_TYPE f16mat2x4
#define A_TYPE32 mat2x4
#endif #endif
#endif #endif
@ -1424,6 +1430,11 @@ float bf16_to_fp32(uint32_t u)
return uintBitsToFloat(u << 16); return uintBitsToFloat(u << 16);
} }
vec4 bf16_to_fp32(uvec4 u)
{
return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w));
}
float e8m0_to_fp32(uint8_t x) { float e8m0_to_fp32(uint8_t x) {
uint32_t bits; uint32_t bits;

View File

@ -364,11 +364,11 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
}; };
// Shaders with f16 B_TYPE // Shaders with f16 B_TYPE
string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("f16")}, {"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
// bf16 // bf16
{ {
@ -384,8 +384,8 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
if (!(coopmat || coopmat2)) if (!(coopmat || coopmat2))
#endif #endif
{ {
string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"B_TYPE32", "vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_bf16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE("bf16")}, {"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
} }
@ -408,13 +408,13 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
// don't generate f32 variants for coopmat2 // don't generate f32 variants for coopmat2
if (!coopmat2) { if (!coopmat2) {
string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
if (tname != "f16" && tname != "f32") { if (tname != "f16" && tname != "f32") {
string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc);
string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"B_TYPE32", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc);
} }
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
@ -560,10 +560,14 @@ void process_shaders() {
string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}});
string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}});
string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}});
string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}});
string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}});
for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { for (std::string t : {"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) {
string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@ -713,6 +717,10 @@ void process_shaders() {
string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}})); string_to_spv("im2col_f32_f16", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}})); string_to_spv("im2col_f32_f16_rte", "im2col.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
string_to_spv("im2col_3d_f32", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("im2col_3d_f32_f16", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}));
string_to_spv("im2col_3d_f32_f16_rte", "im2col_3d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}));
string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}}));
string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}});

View File

@ -815,6 +815,7 @@ static ggml_backend_i ggml_backend_webgpu_i = {
/* .graph_compute = */ ggml_backend_webgpu_graph_compute, /* .graph_compute = */ ggml_backend_webgpu_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
/* End GGML Backend Interface */ /* End GGML Backend Interface */
@ -1394,17 +1395,15 @@ static ggml_backend_dev_t ggml_backend_webgpu_reg_get_device(ggml_backend_reg_t
webgpu_context ctx = reg_ctx->webgpu_ctx; webgpu_context ctx = reg_ctx->webgpu_ctx;
wgpu::RequestAdapterOptions options = {}; wgpu::RequestAdapterOptions options = {};
auto callback = [](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message,
void * userdata) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
*static_cast<wgpu::Adapter *>(userdata) = std::move(adapter);
};
void * userdata = &ctx->adapter;
ctx->instance.WaitAny( ctx->instance.WaitAny(
ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous, callback, userdata), UINT64_MAX); ctx->instance.RequestAdapter(&options, wgpu::CallbackMode::AllowSpontaneous,
[&ctx](wgpu::RequestAdapterStatus status, wgpu::Adapter adapter, const char * message) {
if (status != wgpu::RequestAdapterStatus::Success) {
GGML_LOG_ERROR("ggml_webgpu: Failed to get an adapter: %s\n", message);
return;
}
ctx->adapter = std::move(adapter);
}), UINT64_MAX);
GGML_ASSERT(ctx->adapter != nullptr); GGML_ASSERT(ctx->adapter != nullptr);
ctx->adapter.GetLimits(&ctx->limits); ctx->adapter.GetLimits(&ctx->limits);

View File

@ -586,6 +586,7 @@ static ggml_backend_i ggml_backend_zdnn_i = {
/* .graph_compute = */ ggml_backend_zdnn_graph_compute, /* .graph_compute = */ ggml_backend_zdnn_graph_compute,
/* .event_record = */ NULL, /* .event_record = */ NULL,
/* .event_wait = */ NULL, /* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
}; };
static ggml_guid_t ggml_backend_zdnn_guid(void) { static ggml_guid_t ggml_backend_zdnn_guid(void) {

View File

@ -3623,6 +3623,7 @@ struct ggml_tensor * ggml_get_rows(
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b) { struct ggml_tensor * b) {
GGML_ASSERT(a->ne[2] == b->ne[1]); GGML_ASSERT(a->ne[2] == b->ne[1]);
GGML_ASSERT(a->ne[3] == b->ne[2]);
GGML_ASSERT(b->ne[3] == 1); GGML_ASSERT(b->ne[3] == 1);
GGML_ASSERT(b->type == GGML_TYPE_I32); GGML_ASSERT(b->type == GGML_TYPE_I32);

View File

@ -1166,50 +1166,51 @@ void gguf_set_tensor_data(struct gguf_context * ctx, const char * name, const vo
ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const ctx->info[tensor_id].t.data = (void *)(uintptr_t)data; // double cast suppresses warning about casting away const
} }
struct gguf_writer { struct gguf_writer_base {
std::vector<int8_t> & buf; size_t written_bytes {0u};
gguf_writer(std::vector<int8_t> & buf) : buf(buf) {} ~gguf_writer_base(void) {}
// we bet on devirtualization
virtual void write(int8_t val) = 0;
virtual void write(const std::vector<int8_t> & val) = 0;
virtual void write_tensor_data(const struct gguf_tensor_info & info, size_t offset_data, size_t alignment) = 0;
template <typename T> template <typename T>
void write(const T & val) const { void write(const T & val) {
for (size_t i = 0; i < sizeof(val); ++i) { for (size_t i = 0; i < sizeof(val); ++i) {
buf.push_back(reinterpret_cast<const int8_t *>(&val)[i]); write(reinterpret_cast<const int8_t *>(&val)[i]);
} }
} }
void write(const std::vector<int8_t> & val) const { void write(const bool & val) {
buf.insert(buf.end(), val.begin(), val.end());
}
void write(const bool & val) const {
const int8_t val8 = val ? 1 : 0; const int8_t val8 = val ? 1 : 0;
write(val8); write(val8);
} }
void write(const std::string & val) const { void write(const std::string & val) {
{ {
const uint64_t n = val.length(); const uint64_t n = val.length();
write(n); write(n);
} }
for (size_t i = 0; i < val.length(); ++i) { for (size_t i = 0; i < val.length(); ++i) {
buf.push_back(reinterpret_cast<const int8_t *>(val.data())[i]); write((val.data())[i]);
} }
} }
void write(const char * val) const { void write(const char * val) {
write(std::string(val)); write(std::string(val));
} }
void write(const enum ggml_type & val) const { void write(const enum ggml_type & val) {
write(int32_t(val)); write(int32_t(val));
} }
void write(const enum gguf_type & val) const { void write(const enum gguf_type & val) {
write(int32_t(val)); write(int32_t(val));
} }
void write(const struct gguf_kv & kv) const { void write(const struct gguf_kv & kv) {
const uint64_t ne = kv.get_ne(); const uint64_t ne = kv.get_ne();
write(kv.get_key()); write(kv.get_key());
@ -1250,7 +1251,7 @@ struct gguf_writer {
} }
} }
void write_tensor_meta(const struct gguf_tensor_info & info) const { void write_tensor_meta(const struct gguf_tensor_info & info) {
write(info.t.name); write(info.t.name);
const uint32_t n_dims = ggml_n_dims(&info.t); const uint32_t n_dims = ggml_n_dims(&info.t);
@ -1263,14 +1264,33 @@ struct gguf_writer {
write(info.offset); write(info.offset);
} }
void pad(const size_t alignment) const { void pad(const size_t alignment) {
while (buf.size() % alignment != 0) { while (written_bytes % alignment != 0) {
const int8_t zero = 0; const int8_t zero = 0;
write(zero); write(zero);
} }
} }
};
void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) const { // vector buffer based writer
struct gguf_writer_buf final : public gguf_writer_base {
std::vector<int8_t> & buf;
gguf_writer_buf(std::vector<int8_t> & buf) : buf(buf) {}
using gguf_writer_base::write;
void write(const int8_t val) override {
buf.push_back(val);
written_bytes++;
}
void write(const std::vector<int8_t> & val) override {
buf.insert(buf.end(), val.begin(), val.end());
written_bytes += val.size();
}
void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
GGML_ASSERT(buf.size() - offset_data == info.offset); GGML_ASSERT(buf.size() - offset_data == info.offset);
GGML_ASSERT(ggml_is_contiguous(&info.t)); GGML_ASSERT(ggml_is_contiguous(&info.t));
@ -1284,14 +1304,58 @@ struct gguf_writer {
GGML_ASSERT(info.t.data); GGML_ASSERT(info.t.data);
memcpy(buf.data() + offset, info.t.data, nbytes); memcpy(buf.data() + offset, info.t.data, nbytes);
} }
written_bytes += nbytes;
pad(alignment); pad(alignment);
} }
}; };
void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) { // file based writer
const struct gguf_writer gw(buf); struct gguf_writer_file final : public gguf_writer_base {
FILE * file;
gguf_writer_file(FILE* file) : file(file) {}
using gguf_writer_base::write;
void write(const int8_t val) override {
const auto real_val = static_cast<uint8_t>(val);
const auto ret = fputc(real_val, file);
written_bytes++;
if (ret != real_val) {
throw std::runtime_error("unexpected fputc result '" + std::to_string(ret) + "' instead of '" + std::to_string((int)real_val) + "'");
}
}
void write(const std::vector<int8_t> & val) override {
const auto ret = fwrite(val.data(), 1, val.size(), file);
written_bytes += val.size();
if (ret != val.size()) {
throw std::runtime_error("unexpected fwrite number of bytes written, '" + std::to_string(ret) + "' instead of '" + std::to_string(val.size()) + "'");
}
}
void write_tensor_data(const struct gguf_tensor_info & info, const size_t offset_data, const size_t alignment) override {
GGML_ASSERT(written_bytes - offset_data == info.offset);
GGML_ASSERT(ggml_is_contiguous(&info.t));
const size_t nbytes = ggml_nbytes(&info.t);
std::vector<int8_t> buf(nbytes);
if (info.t.buffer) {
ggml_backend_tensor_get(&info.t, buf.data(), 0, nbytes);
} else {
GGML_ASSERT(info.t.data);
memcpy(buf.data(), info.t.data, nbytes);
}
write(buf);
pad(alignment);
}
};
template <typename writer_t>
static void gguf_write_out(const struct gguf_context * ctx, writer_t & gw, bool only_meta) {
const int64_t n_kv = gguf_get_n_kv(ctx); const int64_t n_kv = gguf_get_n_kv(ctx);
const int64_t n_tensors = gguf_get_n_tensors(ctx); const int64_t n_tensors = gguf_get_n_tensors(ctx);
@ -1321,7 +1385,7 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu
return; return;
} }
const size_t offset_data = gw.buf.size(); const size_t offset_data = gw.written_bytes;
// write tensor data // write tensor data
for (int64_t i = 0; i < n_tensors; ++i) { for (int64_t i = 0; i < n_tensors; ++i) {
@ -1329,6 +1393,11 @@ void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & bu
} }
} }
void gguf_write_to_buf(const struct gguf_context * ctx, std::vector<int8_t> & buf, bool only_meta) {
gguf_writer_buf gw(buf);
gguf_write_out(ctx, gw, only_meta);
}
bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) { bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, bool only_meta) {
FILE * file = ggml_fopen(fname, "wb"); FILE * file = ggml_fopen(fname, "wb");
@ -1337,11 +1406,17 @@ bool gguf_write_to_file(const struct gguf_context * ctx, const char * fname, boo
return false; return false;
} }
std::vector<int8_t> buf; try {
gguf_write_to_buf(ctx, buf, only_meta); gguf_writer_file gw(file);
const bool ok = fwrite(buf.data(), 1, buf.size(), file) == buf.size(); gguf_write_out(ctx, gw, only_meta);
} catch (const std::runtime_error& ex) {
GGML_LOG_ERROR("%s: failed to write GGUF data into '%s': %s\n", __func__, fname, ex.what());
fclose(file);
return false;
}
fclose(file); fclose(file);
return ok; return true;
} }
size_t gguf_get_meta_size(const struct gguf_context * ctx) { size_t gguf_get_meta_size(const struct gguf_context * ctx) {

View File

@ -109,6 +109,7 @@ class Keys:
POOLING_TYPE = "{arch}.pooling_type" POOLING_TYPE = "{arch}.pooling_type"
LOGIT_SCALE = "{arch}.logit_scale" LOGIT_SCALE = "{arch}.logit_scale"
DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id" DECODER_START_TOKEN_ID = "{arch}.decoder_start_token_id"
DECODER_BLOCK_COUNT = "{arch}.decoder_block_count"
ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping" ATTN_LOGIT_SOFTCAPPING = "{arch}.attn_logit_softcapping"
FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping" FINAL_LOGIT_SOFTCAPPING = "{arch}.final_logit_softcapping"
SWIN_NORM = "{arch}.swin_norm" SWIN_NORM = "{arch}.swin_norm"
@ -231,10 +232,11 @@ class Keys:
MIDDLE_ID = "tokenizer.ggml.middle_token_id" MIDDLE_ID = "tokenizer.ggml.middle_token_id"
class Adapter: class Adapter:
TYPE = "adapter.type" TYPE = "adapter.type"
LORA_ALPHA = "adapter.lora.alpha" LORA_ALPHA = "adapter.lora.alpha"
LORA_TASK_NAME = "adapter.lora.task_name" LORA_TASK_NAME = "adapter.lora.task_name"
LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix" LORA_PROMPT_PREFIX = "adapter.lora.prompt_prefix"
ALORA_INVOCATION_TOKENS = "adapter.alora.invocation_tokens"
class IMatrix: class IMatrix:
CHUNK_COUNT = "imatrix.chunk_count" CHUNK_COUNT = "imatrix.chunk_count"

View File

@ -676,6 +676,9 @@ class GGUFWriter:
def add_decoder_start_token_id(self, id: int) -> None: def add_decoder_start_token_id(self, id: int) -> None:
self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id) self.add_uint32(Keys.LLM.DECODER_START_TOKEN_ID.format(arch=self.arch), id)
def add_decoder_block_count(self, value: int) -> None:
self.add_uint32(Keys.LLM.DECODER_BLOCK_COUNT.format(arch=self.arch), value)
def add_embedding_length_per_layer_input(self, value: int) -> None: def add_embedding_length_per_layer_input(self, value: int) -> None:
self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value) self.add_uint32(Keys.LLM.EMBD_LENGTH_PER_LAYER_INP.format(arch=self.arch), value)

View File

@ -583,6 +583,10 @@ extern "C" {
// Note: loaded adapters will be free when the associated model is deleted // Note: loaded adapters will be free when the associated model is deleted
LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter); LLAMA_API void llama_adapter_lora_free(struct llama_adapter_lora * adapter);
// Get the invocation tokens if the current lora is an alora
LLAMA_API uint64_t llama_adapter_get_alora_n_invocation_tokens(const struct llama_adapter_lora * adapter);
LLAMA_API const llama_token * llama_adapter_get_alora_invocation_tokens (const struct llama_adapter_lora * adapter);
// The following functions operate on a llama_context, hence the naming: llama_verb_... // The following functions operate on a llama_context, hence the naming: llama_verb_...
// Add a loaded LoRA adapter to given context // Add a loaded LoRA adapter to given context

Binary file not shown.

After

Width:  |  Height:  |  Size: 14 KiB

View File

@ -0,0 +1,77 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<svg
id="Layer_1"
version="1.1"
viewBox="0 0 250 250"
sodipodi:docname="llama1-icon-transparent.svg"
width="250"
height="250"
inkscape:version="1.4.2 (ebf0e940d0, 2025-05-08)"
xmlns:inkscape="http://www.inkscape.org/namespaces/inkscape"
xmlns:sodipodi="http://sodipodi.sourceforge.net/DTD/sodipodi-0.dtd"
xmlns="http://www.w3.org/2000/svg"
xmlns:svg="http://www.w3.org/2000/svg">
<sodipodi:namedview
id="namedview7"
pagecolor="#505050"
bordercolor="#ffffff"
borderopacity="1"
inkscape:showpageshadow="0"
inkscape:pageopacity="0"
inkscape:pagecheckerboard="1"
inkscape:deskcolor="#505050"
inkscape:zoom="2.48"
inkscape:cx="49.596774"
inkscape:cy="189.91935"
inkscape:window-width="3440"
inkscape:window-height="1440"
inkscape:window-x="0"
inkscape:window-y="0"
inkscape:window-maximized="1"
inkscape:current-layer="Layer_1" />
<!-- Generator: Adobe Illustrator 29.3.1, SVG Export Plug-In . SVG Version: 2.1.0 Build 151) -->
<defs
id="defs1">
<style
id="style1">
.st0 {
fill: #ff8236;
}
.st1 {
fill: #fff;
}
.st2 {
fill: #1b1f20;
}
</style>
</defs>
<g
id="g7">
<g
id="g6"
transform="translate(-995.51066,-129.70875)">
<path
class="st0"
d="m 1163.3,226.8 -13.5,24 c -17.8,-13.7 -44.2,-15.7 -62,-1 -28.7,23.7 -26.7,78.5 18,78.8 12.5,0 23.1,-5.9 34.5,-9.8 l 6,23.9 c -10.1,4.7 -20.4,9.5 -31.5,11 -101.2,13.8 -95.4,-132.3 -3.9,-139.9 19.2,-1.6 36.1,3.4 52.5,13 z"
id="path4" />
<path
class="st0"
d="m 1093.4,203.8 c -15.4,4.6 -29.7,13.1 -40.5,25 -2,-24.2 3.4,-73.1 30.3,-82.7 4,-1.4 17.7,-4.9 17.3,2.2 -0.4,7.1 -9.9,19.3 -12.2,25.9 -4,11.6 -0.3,19.6 5.2,29.7 z"
id="path5" />
<polygon
class="st0"
points="1131.4,307.8 1116.4,307.8 1116.4,290.8 1099.4,290.8 1099.4,276.8 1114.9,276.8 1116.4,275.3 1116.4,258.8 1131.4,258.8 1131.4,276.8 1147.4,276.8 1147.4,290.8 1131.4,290.8 "
id="polygon5" />
<polygon
class="st0"
points="1186.4,290.8 1186.4,307.8 1171.4,307.8 1171.4,290.8 1155.4,290.8 1155.4,276.8 1171.4,276.8 1171.4,258.8 1186.4,258.8 1186.4,275.3 1187.9,276.8 1203.4,276.8 1203.4,290.8 "
id="polygon6" />
<path
class="st0"
d="m 1142.3,156.9 c 2,3 -9.3,15.9 -11.1,19.2 -5.2,9.8 -1.7,15.4 2.2,24.7 -11.3,-1.7 -21.8,-0.3 -33,1 2.5,-21.5 14.6,-52.8 41.9,-44.9 z"
id="path6" />
</g>
</g>
</svg>

After

Width:  |  Height:  |  Size: 2.6 KiB

BIN
media/llama1-icon.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

Some files were not shown because too many files have changed in this diff Show More