Merge branch 'master' into compilade/convert-prequant

Francis Couture-Harpin 2025-09-09 14:23:06 -04:00
commit 0d5cfed596
173 changed files with 9476 additions and 3841 deletions

View File

@ -22,7 +22,7 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: false
BinPackArguments: true
BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach

View File

@ -17,7 +17,7 @@ jobs:
steps:
- uses: actions/stale@v5
with:
exempt-issue-labels: "refactoring,help wanted,good first issue,research,bug,roadmap"
exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
days-before-issue-stale: 30
days-before-issue-close: 14
stale-issue-label: "stale"

View File

@ -16,6 +16,9 @@
- Use the following format for the squashed commit title: `<module> : <commit title> (#<issue_number>)`. For example: `utils : fix typo in utils.py (#1234)`
- Optionally pick a `<module>` from here: https://github.com/ggml-org/llama.cpp/wiki/Modules
- Consider adding yourself to [CODEOWNERS](CODEOWNERS)
- Let authors who are also collaborators merge their own PRs
- When merging a PR by a contributor, make sure you have a good understanding of the changes
- Be mindful of maintenance: most of the work going into a feature happens after the PR is merged. If the PR author is not committed to contribute long-term, someone else needs to take responsibility (you)
# Coding guidelines

View File

@ -1263,6 +1263,18 @@ static std::string list_builtin_chat_templates() {
return msg.str();
}
static bool is_truthy(const std::string & value) {
return value == "on" || value == "enabled" || value == "1";
}
static bool is_falsey(const std::string & value) {
return value == "off" || value == "disabled" || value == "0";
}
static bool is_autoy(const std::string & value) {
return value == "auto" || value == "-1";
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// load dynamic backends
ggml_backend_load_all();
@ -1544,21 +1556,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_chunks = value;
}
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg(
{"-fa", "--flash-attn"}, "FA",
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')", llama_flash_attn_type_name(params.flash_attn_type)),
add_opt(common_arg({ "-fa", "--flash-attn" }, "[on|off|auto]",
string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')",
llama_flash_attn_type_name(params.flash_attn_type)),
[](common_params & params, const std::string & value) {
if (value == "on" || value == "enabled") {
if (is_truthy(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
} else if (value == "off" || value == "disabled") {
} else if (is_falsey(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (value == "auto") {
} else if (is_autoy(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
} else {
throw std::runtime_error(string_format("error: unknown value for --flash-attn: '%s'\n", value.c_str()));
throw std::runtime_error(
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
}
}
).set_env("LLAMA_ARG_FLASH_ATTN"));
}).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys",
@ -2466,7 +2478,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM",
string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
[](common_params & params, int value) {
params.n_gpu_layers = value;
if (!llama_supports_gpu_offload()) {
@ -3134,13 +3146,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str());
}
));
add_opt(common_arg(
{"--log-colors"},
"Enable colored logging",
[](common_params &) {
common_log_set_colors(common_log_main(), true);
add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
"Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"'auto' enables colors when output is to a terminal",
[](common_params &, const std::string & value) {
if (is_truthy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
} else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
}
).set_env("LLAMA_LOG_COLORS"));
}).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg(
{"-v", "--verbose", "--log-verbose"},
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)",

View File

@ -163,6 +163,19 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
throw std::runtime_error("Invalid tool_choice: " + tool_choice);
}
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
common_chat_templates_inputs dummy_inputs;
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
dummy_inputs.messages = {msg};
dummy_inputs.enable_thinking = false;
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
dummy_inputs.enable_thinking = true;
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
}
template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs;
@ -618,11 +631,13 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2: return "FireFunction v2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2: return "Functionary v3.2";
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1: return "Functionary v3.1 Llama 3.1";
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1: return "DeepSeek V3.1";
case COMMON_CHAT_FORMAT_HERMES_2_PRO: return "Hermes 2 Pro";
case COMMON_CHAT_FORMAT_COMMAND_R7B: return "Command R7B";
case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default:
throw std::runtime_error("Unknown chat format");
}
@ -684,11 +699,13 @@ static void parse_json_tool_calls(
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first
? builder.try_consume_regex(*function_regex_start_only)
: function_regex
? builder.try_find_regex(*function_regex, from)
: std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
@ -723,6 +740,8 @@ static void parse_json_tool_calls(
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
@ -1184,6 +1203,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
});
return data;
}
static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
// When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = true;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{ "type", "object" },
{ "properties",
{
{ "name",
{
{ "type", "string" },
{ "const", function.at("name") },
} },
{ "arguments", function.at("parameters") },
} },
{ "required", json::array({ "name", "arguments" }) },
});
});
auto schema = json{
{ "type", "array" },
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
{ "minItems", 1 },
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
" \"</TOOLCALL>\"");
});
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ?
"[\\s\\S]*?(</think>\\s*)" :
"(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<TOOLCALL>)[\\s\\S]*" });
}
return data;
}
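For reference, the grammar built above constrains a tool-call turn to a JSON array of objects with "name" and "arguments" fields wrapped in the tag pair, i.e. roughly `<TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>` (the tool name and arguments are made up for illustration); the Nemotron V2 parser added further down consumes exactly this shape.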
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
@ -1313,6 +1393,71 @@ static common_chat_params common_chat_params_init_deepseek_r1(const common_chat_
}
return data;
}
static common_chat_params common_chat_params_init_deepseek_v3_1(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Pass thinking context for DeepSeek V3.1 template
json additional_context = {
{"thinking", inputs.enable_thinking},
};
auto prompt = apply(tmpl, inputs,
/* messages_override= */ inputs.messages,
/* tools_override= */ std::nullopt,
additional_context);
data.prompt = prompt;
data.format = COMMON_CHAT_FORMAT_DEEPSEEK_V3_1;
if (string_ends_with(data.prompt, "<think>")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
if (inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = inputs.tool_choice != COMMON_CHAT_TOOL_CHOICE_REQUIRED && inputs.json_schema.is_null();
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
std::vector<std::string> tool_rules;
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
std::string name = function.at("name");
auto parameters = function.at("parameters");
builder.resolve_refs(parameters);
tool_rules.push_back(builder.add_rule(name + "-call",
"( \"<tool▁call▁begin>\" )? \"" + name + "<tool▁sep>"
"\" " + builder.add_schema(name + "-args", parameters) + " "
"\"<tool▁call▁end>\""));
});
// Distill Qwen 7B & 32B models seem confused re/ syntax of their tool call opening tag,
// so we accept common variants (then it's all constrained)
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"( \"<tool▁calls▁begin>\" | \"<tool_calls_begin>\" | \"<tool calls begin>\" | \"<tool\\\\_calls\\\\_begin>\" | \"<tool▁calls>\" ) "
"(" + string_join(tool_rules, " | ") + ")" + (inputs.parallel_tool_calls ? "*" : "") + " "
"\"<tool▁calls▁end>\""
" space");
data.grammar_triggers.push_back({
COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)[\\s\\S]*"
});
data.preserved_tokens = {
"<think>",
"</think>",
"<tool▁calls▁begin>",
"<tool▁call▁begin>",
"<tool▁sep>",
"<tool▁call▁end>",
"<tool▁calls▁end>",
};
});
}
return data;
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
@ -1334,6 +1479,66 @@ static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static common_chat_params common_chat_params_init_gpt_oss(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
auto prompt = apply(tmpl, inputs);
@ -1830,7 +2035,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
"(\\s*"
"\\s*("
"(?:<tool_call>"
"|<function"
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
@ -2060,6 +2265,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@ -2263,6 +2495,12 @@ static common_chat_params common_chat_templates_apply_jinja(
}
}
// DeepSeek V3.1: detect based on specific patterns in the template
if (src.find("message['prefix'] is defined and message['prefix'] and thinking") != std::string::npos &&
params.json_schema.is_null()) {
return common_chat_params_init_deepseek_v3_1(tmpl, params);
}
// DeepSeek R1: use handler in all cases except json schema (thinking / tools).
if (src.find("<tool▁calls▁begin>") != std::string::npos && params.json_schema.is_null()) {
return common_chat_params_init_deepseek_r1(tmpl, params);
@ -2293,6 +2531,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs);
}
// Nemotron v2
if (src.find("<SPECIAL_10>") != std::string::npos) {
return common_chat_params_init_nemotron_v2(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) {
@ -2430,6 +2673,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
@ -2454,6 +2700,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}

View File

@ -107,11 +107,13 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
COMMON_CHAT_FORMAT_HERMES_2_PRO,
COMMON_CHAT_FORMAT_COMMAND_R7B,
COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
@ -198,6 +200,8 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_p
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
// Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);

View File

@ -843,9 +843,10 @@ public:
_build_object_rule(
properties, required, name,
schema.contains("additionalProperties") ? schema["additionalProperties"] : json()));
} else if ((schema_type.is_null() || schema_type == "object") && schema.contains("allOf")) {
} else if ((schema_type.is_null() || schema_type == "object" || schema_type == "string") && schema.contains("allOf")) {
std::unordered_set<std::string> required;
std::vector<std::pair<std::string, json>> properties;
std::map<std::string, size_t> enum_values;
std::string hybrid_name = name;
std::function<void(const json &, bool)> add_component = [&](const json & comp_schema, bool is_required) {
if (comp_schema.contains("$ref")) {
@ -857,6 +858,14 @@ public:
required.insert(prop.key());
}
}
} else if (comp_schema.contains("enum")) {
for (const auto & v : comp_schema["enum"]) {
const auto rule = _generate_constant_rule(v);
if (enum_values.find(rule) == enum_values.end()) {
enum_values[rule] = 0;
}
enum_values[rule] += 1;
}
} else {
// todo warning
}
@ -870,6 +879,17 @@ public:
add_component(t, true);
}
}
if (!enum_values.empty()) {
std::vector<std::string> enum_intersection;
for (const auto & p : enum_values) {
if (p.second == schema["allOf"].size()) {
enum_intersection.push_back(p.first);
}
}
if (!enum_intersection.empty()) {
return _add_rule(rule_name, "(" + string_join(enum_intersection, " | ") + ") space");
}
}
return _add_rule(rule_name, _build_object_rule(properties, required, hybrid_name, json()));
} else if ((schema_type.is_null() || schema_type == "array") && (schema.contains("items") || schema.contains("prefixItems"))) {
json items = schema.contains("items") ? schema["items"] : schema["prefixItems"];

View File

@ -4,17 +4,52 @@
#include <condition_variable>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex>
#include <sstream>
#include <thread>
#include <vector>
#if defined(_WIN32)
# include <io.h>
# include <windows.h>
# define isatty _isatty
# define fileno _fileno
#else
# include <unistd.h>
#endif // defined(_WIN32)
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity;
}
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
}
@ -353,6 +388,11 @@ struct common_log * common_log_init() {
struct common_log * common_log_main() {
static struct common_log log;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(common_log_should_use_colors_auto());
});
return &log;
}
@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) {
log->set_file(file);
}
void common_log_set_colors(struct common_log * log, bool colors) {
log->set_colors(colors);
void common_log_set_colors(struct common_log * log, log_colors colors) {
if (colors == LOG_COLORS_AUTO) {
log->set_colors(common_log_should_use_colors_auto());
return;
}
if (colors == LOG_COLORS_DISABLED) {
log->set_colors(false);
return;
}
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
log->set_colors(true);
}
void common_log_set_prefix(struct common_log * log, bool prefix) {

View File

@ -24,6 +24,12 @@
#define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0
enum log_colors {
LOG_COLORS_AUTO = -1,
LOG_COLORS_DISABLED = 0,
LOG_COLORS_ENABLED = 1,
};
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity()
extern int common_log_verbosity_thold;
@ -66,7 +72,7 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
//
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe
void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix

View File

@ -433,10 +433,6 @@ class ModelBase:
# data = data_torch.squeeze().numpy()
data = data_torch.numpy()
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
data = data_torch.numpy()
n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
@ -5236,6 +5232,29 @@ class Gemma3Model(TextModel):
return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
def set_gguf_parameters(self):
super().set_gguf_parameters()
# Override the sliding window size as it gets adjusted by the Gemma3TextConfig
# constructor. We want to use the value from the original model's config.json.
# ref: https://github.com/huggingface/transformers/pull/40700
with open(self.dir_model / "config.json", "r", encoding="utf-8") as f:
config = json.load(f)
orig_sliding_window = config.get("sliding_window")
if orig_sliding_window is None:
raise ValueError("sliding_window not found in model config - this is required for the model")
logger.info(f"Using original sliding_window from config: {orig_sliding_window} "
f"instead of {self.hparams['sliding_window']}")
self.gguf_writer.add_sliding_window(orig_sliding_window)
self._try_set_pooling_type()
@ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self):

View File

@ -12,7 +12,7 @@ import json
from math import prod
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig
from transformers import AutoConfig, AutoTokenizer
import torch
@ -26,6 +26,8 @@ import gguf
# reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase
from gguf.constants import GGUFValueType
logger = logging.getLogger("lora-to-gguf")
@ -369,7 +371,31 @@ if __name__ == '__main__':
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self):
logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
alora_invocation_tokens = lparams.get("alora_invocation_tokens")
invocation_string = lparams.get("invocation_string")
if invocation_string and not alora_invocation_tokens:
logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
base_model_path_or_id = hparams.get("_name_or_path")
try:
tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
except ValueError:
logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
raise
# NOTE: There's an off-by-one with the older aLoRAs where
# the invocation string includes the "<|start_of_turn|>"
# token, but the adapters themselves were trained to
# activate _after_ that first token, so we drop it here.
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
if alora_invocation_tokens:
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
self.gguf_writer.add_key_value(
gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
alora_invocation_tokens,
GGUFValueType.ARRAY,
GGUFValueType.UINT32,
)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters

View File

@ -293,17 +293,14 @@ We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers fr
## Environment variable setup
### GGML_CANN_ASYNC_MODE
Enables asynchronous operator submission. Disabled by default.
### GGML_CANN_MEM_POOL
Specifies the memory pool management strategy:
Specifies the memory pool management strategy. The default is vmm.
- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool.
- prio: Employs a priority queue-based memory pool management.
- leg: Uses a fixed-size buffer pool.
### GGML_CANN_DISABLE_BUF_POOL_CLEAN
@ -312,5 +309,8 @@ Controls automatic cleanup of the memory pool. This option is only effective whe
### GGML_CANN_WEIGHT_NZ
Converting the matmul weight format from ND to NZ can significantly improve performance on the 310I DUO NPU.
Converts the matmul weight format from ND to NZ to improve performance. Enabled by default.
### GGML_CANN_ACL_GRAPH
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.

View File

@ -42,18 +42,6 @@ cmake --build build --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc)
```
- By default, NNPA is disabled. To enable it:
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
-DGGML_NNPA=ON
cmake --build build --config Release -j $(nproc)
```
- For debug builds:
```bash
@ -164,15 +152,11 @@ All models need to be converted to Big-Endian. You can achieve this in three cas
Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation.
### 2. NNPA Vector Intrinsics Acceleration
Only available in IBM z16/LinuxONE 4 or later system with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
### 3. zDNN Accelerator (WIP)
### 2. zDNN Accelerator (WIP)
Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines.
### 4. Spyre Accelerator
### 3. Spyre Accelerator
_Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._
@ -230,10 +214,6 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
```
5. `-DGGML_NNPA=ON` generates gibberish output
Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either try reducing the number of threads, or disable the compile option using `-DGGML_NNPA=OFF`.
## Getting Help on IBM Z & LinuxONE
1. **Bugs, Feature Requests**
@ -258,38 +238,38 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
## Appendix B: SIMD Support Matrix
| | VX/VXE/VXE2 | NNPA | zDNN | Spyre |
| ---------- | ----------- | ---- | ---- | ----- |
| FP32 | ✅ | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ | ❓ |
| BF16 | 🚫 | 🚫 | ❓ | ❓ |
| Q4_0 | ✅ | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | 🚫 | ❓ | ❓ |
| Q5_0 | ✅ | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ✅ | ❓ | ❓ |
| Q2_K | 🚫 | 🚫 | ❓ | ❓ |
| Q3_K | ✅ | ✅ | ❓ | ❓ |
| Q4_K | ✅ | ✅ | ❓ | ❓ |
| Q5_K | ✅ | ✅ | ❓ | ❓ |
| Q6_K | ✅ | ✅ | ❓ | ❓ |
| TQ1_0 | 🚫 | 🚫 | ❓ | ❓ |
| TQ2_0 | 🚫 | 🚫 | ❓ | ❓ |
| IQ2_XXS | 🚫 | 🚫 | ❓ | ❓ |
| IQ2_XS | 🚫 | 🚫 | ❓ | ❓ |
| IQ2_S | 🚫 | 🚫 | ❓ | ❓ |
| IQ3_XXS | 🚫 | 🚫 | ❓ | ❓ |
| IQ3_S | 🚫 | 🚫 | ❓ | ❓ |
| IQ1_S | 🚫 | 🚫 | ❓ | ❓ |
| IQ1_M | 🚫 | 🚫 | ❓ | ❓ |
| IQ4_NL | ✅ | ✅ | ❓ | ❓ |
| IQ4_XS | ✅ | ✅ | ❓ | ❓ |
| FP32->FP16 | 🚫 | ✅ | ❓ | ❓ |
| FP16->FP32 | 🚫 | ✅ | ❓ | ❓ |
| | VX/VXE/VXE2 | zDNN | Spyre |
|------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ❓ | ❓ |
| BF16 | 🚫 | ❓ | ❓ |
| Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | ❓ | ❓ |
| Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ❓ | ❓ |
| Q2_K | 🚫 | ❓ | ❓ |
| Q3_K | ✅ | ❓ | ❓ |
| Q4_K | ✅ | ❓ | ❓ |
| Q5_K | ✅ | ❓ | ❓ |
| Q6_K | ✅ | ❓ | ❓ |
| TQ1_0 | 🚫 | ❓ | ❓ |
| TQ2_0 | 🚫 | ❓ | ❓ |
| IQ2_XXS | 🚫 | ❓ | ❓ |
| IQ2_XS | 🚫 | ❓ | ❓ |
| IQ2_S | 🚫 | ❓ | ❓ |
| IQ3_XXS | 🚫 | ❓ | ❓ |
| IQ3_S | 🚫 | ❓ | ❓ |
| IQ1_S | 🚫 | ❓ | ❓ |
| IQ1_M | 🚫 | ❓ | ❓ |
| IQ4_NL | ✅ | ❓ | ❓ |
| IQ4_XS | ✅ | ❓ | ❓ |
| FP32->FP16 | 🚫 | ❓ | ❓ |
| FP16->FP32 | 🚫 | ❓ | ❓ |
- ✅ - acceleration available
- 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025.
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025.

View File

@ -333,17 +333,17 @@ static void print_params(struct my_llama_hparams * params) {
}
static void print_tensor_info(const struct ggml_context * ctx) {
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
LOG_INF("%s: Allocating ", __func__);
int64_t total = 1;
int i = 0;
for (; i < ggml_n_dims(t); ++i) {
if (i > 0) LOG("x ");
LOG("[%" PRId64 "] ", t->ne[i]);
if (i > 0) { LOG_INF("x "); }
LOG_INF("[%" PRId64 "] ", t->ne[i]);
total *= t->ne[i];
}
if (i > 1) LOG("= [%" PRId64 "] ", total);
LOG("float space for %s\n", ggml_get_name(t));
if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
LOG_INF("float space for %s\n", ggml_get_name(t));
}
}

View File

@ -28,6 +28,15 @@ static std::string ggml_ne_string(const ggml_tensor * t) {
return str;
}
static inline float ggml_compute_bf16_to_fp32(ggml_bf16_t h) {
union {
float f;
uint32_t i;
} u;
u.i = (uint32_t)h.bits << 16;
return u.f;
}
static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t * nb, size_t i0, size_t i1, size_t i2, size_t i3) {
size_t i = i3 * nb[3] + i2 * nb[2] + i1 * nb[1] + i0 * nb[0];
float v;
@ -43,6 +52,8 @@ static float ggml_get_float_value(uint8_t * data, ggml_type type, const size_t *
v = (float) *(int16_t *) &data[i];
} else if (type == GGML_TYPE_I8) {
v = (float) *(int8_t *) &data[i];
} else if (type == GGML_TYPE_BF16) {
v = ggml_compute_bf16_to_fp32(*(ggml_bf16_t *) &data[i]);
} else {
GGML_ABORT("fatal error");
}

View File

@ -586,9 +586,10 @@ class SchemaConverter:
properties = list(schema.get('properties', {}).items())
return self._add_rule(rule_name, self._build_object_rule(properties, required, name, schema.get('additionalProperties')))
elif schema_type in (None, 'object') and 'allOf' in schema:
elif schema_type in (None, 'object', 'string') and 'allOf' in schema:
required = set()
properties = []
enum_sets = []
hybrid_name = name
def add_component(comp_schema, is_required):
if (ref := comp_schema.get('$ref')) is not None:
@ -600,6 +601,9 @@ class SchemaConverter:
if is_required:
required.add(prop_name)
if 'enum' in comp_schema:
enum_sets.append(set(comp_schema['enum']))
for t in schema['allOf']:
if 'anyOf' in t:
for tt in t['anyOf']:
@ -607,6 +611,15 @@ class SchemaConverter:
else:
add_component(t, is_required=True)
if enum_sets:
enum_intersection = enum_sets[0]
for s in enum_sets[1:]:
enum_intersection &= s
if enum_intersection:
rule = '(' + ' | '.join((self._generate_constant_rule(v) for v in sorted(enum_intersection))) + ') space'
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, self._build_object_rule(properties, required, hybrid_name, additional_properties=None))
elif schema_type in (None, 'array') and ('items' in schema or 'prefixItems' in schema):

View File

@ -63,7 +63,7 @@ causal-verify-logits: causal-run-original-model causal-run-converted-model
@MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH}
causal-run-original-embeddings:
@./scripts/causal/run-casual-gen-embeddings-org.sh
@./scripts/causal/run-casual-gen-embeddings-org.py
causal-run-converted-embeddings:
@./scripts/causal/run-converted-model-embeddings-logits.sh

View File

@ -1,5 +1,6 @@
--extra-index-url https://download.pytorch.org/whl/cpu
torch~=2.6.0
torchvision~=0.21.0
transformers~=4.55.0
huggingface-hub~=0.34.0
torch
torchvision
transformers
huggingface-hub
accelerate

View File

@ -1,4 +1,4 @@
#/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -3,11 +3,10 @@
import argparse
import os
import importlib
import sys
import torch
import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from pathlib import Path
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@ -43,6 +42,8 @@ if unreleased_model_name:
model = model_class.from_pretrained(model_path)
except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}")
print("Falling back to AutoModelForCausalLM")
model = AutoModelForCausalLM.from_pretrained(model_path)
else:
model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}")

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -9,15 +9,134 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig
import torch
import numpy as np
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
### If you want to dump RoPE activations, apply this monkey patch to the model
### class from Transformers that you are running (replace apertus.modeling_apertus
### with the proper package and class for your model
### === START ROPE DEBUG ===
# from transformers.models.apertus.modeling_apertus import apply_rotary_pos_emb
parser = argparse.ArgumentParser(description='Process model with specified path')
parser.add_argument('--model-path', '-m', help='Path to the model')
# orig_rope = apply_rotary_pos_emb
# torch.set_printoptions(threshold=float('inf'))
# torch.set_printoptions(precision=6, sci_mode=False)
# def debug_rope(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
# # log inputs
# summarize(q, "RoPE.q_in")
# summarize(k, "RoPE.k_in")
# # call original
# q_out, k_out = orig_rope(q, k, cos, sin, position_ids, unsqueeze_dim)
# # log outputs
# summarize(q_out, "RoPE.q_out")
# summarize(k_out, "RoPE.k_out")
# return q_out, k_out
# # Patch it
# import transformers.models.apertus.modeling_apertus as apertus_mod # noqa: E402
# apertus_mod.apply_rotary_pos_emb = debug_rope
### == END ROPE DEBUG ===
def summarize(tensor: torch.Tensor, name: str, max_seq: int = 3, max_vals: int = 3):
"""
Print a tensor in llama.cpp debug style.
Supports:
- 2D tensors (seq, hidden)
- 3D tensors (batch, seq, hidden)
- 4D tensors (batch, seq, heads, dim_per_head) via flattening heads × dim_per_head
Shows first and last max_vals of each vector per sequence position.
"""
t = tensor.detach().to(torch.float32).cpu()
# Determine dimensions
if t.ndim == 3:
_, s, _ = t.shape
elif t.ndim == 2:
_, s = 1, t.shape[0]
t = t.unsqueeze(0)
elif t.ndim == 4:
_, s, _, _ = t.shape
else:
print(f"Skipping tensor due to unsupported dimensions: {t.ndim}")
return
ten_shape = t.shape
print(f"ggml_debug: {name} = (f32) ... = {{{ten_shape}}}")
print(" [")
print(" [")
# Determine indices for first and last sequences
first_indices = list(range(min(s, max_seq)))
last_indices = list(range(max(0, s - max_seq), s))
# Check if there's an overlap between first and last indices or if we're at the edge case of s = 2 * max_seq
has_overlap = bool(set(first_indices) & set(last_indices)) or (max_seq * 2 == s)
# Combine indices
if has_overlap:
# If there's overlap, just use the combined unique indices
indices = sorted(list(set(first_indices + last_indices)))
separator_index = None
else:
# If no overlap, we'll add a separator between first and last sequences
indices = first_indices + last_indices
separator_index = len(first_indices)
for i, si in enumerate(indices):
# Add separator if needed
if separator_index is not None and i == separator_index:
print(" ...")
# Extract appropriate slice
vec = t[0, si]
if vec.ndim == 2: # 4D case: flatten heads × dim_per_head
flat = vec.flatten().tolist()
else: # 2D or 3D case
flat = vec.tolist()
# First and last slices
first = flat[:max_vals]
last = flat[-max_vals:] if len(flat) >= max_vals else flat
first_str = ", ".join(f"{v:12.4f}" for v in first)
last_str = ", ".join(f"{v:12.4f}" for v in last)
print(f" [{first_str}, ..., {last_str}]")
print(" ],")
print(" ]")
print(f" sum = {t.sum().item():.6f}\n")
def debug_hook(name):
def fn(_m, input, output):
if isinstance(input, torch.Tensor):
summarize(input, name + "_in")
elif isinstance(input, (tuple, list)) and isinstance(input[0], torch.Tensor):
summarize(input[0], name + "_in")
if isinstance(output, torch.Tensor):
summarize(output, name + "_out")
elif isinstance(output, (tuple, list)) and isinstance(output[0], torch.Tensor):
summarize(output[0], name + "_out")
return fn
unreleased_model_name = os.getenv("UNRELEASED_MODEL_NAME")
parser = argparse.ArgumentParser(description="Process model with specified path")
parser.add_argument("--model-path", "-m", help="Path to the model")
args = parser.parse_args()
model_path = os.environ.get('MODEL_PATH', args.model_path)
model_path = os.environ.get("MODEL_PATH", args.model_path)
if model_path is None:
parser.error("Model path must be specified either via --model-path argument or MODEL_PATH environment variable")
parser.error(
"Model path must be specified either via --model-path argument or MODEL_PATH environment variable"
)
config = AutoConfig.from_pretrained(model_path)
@ -34,18 +153,30 @@ config = AutoConfig.from_pretrained(model_path)
if unreleased_model_name:
model_name_lower = unreleased_model_name.lower()
unreleased_module_path = f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
unreleased_module_path = (
f"transformers.models.{model_name_lower}.modular_{model_name_lower}"
)
class_name = f"{unreleased_model_name}ForCausalLM"
print(f"Importing unreleased model module: {unreleased_module_path}")
try:
model_class = getattr(importlib.import_module(unreleased_module_path), class_name)
model = model_class.from_pretrained(model_path) # Note: from_pretrained, not fromPretrained
model_class = getattr(
importlib.import_module(unreleased_module_path), class_name
)
model = model_class.from_pretrained(
model_path
) # Note: from_pretrained, not fromPretrained
except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}")
exit(1)
else:
model = AutoModelForCausalLM.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(
model_path, device_map="auto", offload_folder="offload"
)
for name, module in model.named_modules():
if len(list(module.children())) == 0: # only leaf modules
module.register_forward_hook(debug_hook(name))
model_name = os.path.basename(model_path)
# Printing the Model class to allow for easier debugging. This can be useful

View File

@ -1,4 +1,4 @@
#/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -7,7 +7,7 @@ base_model:
Recommended way to run this model:
```sh
llama-server -hf {namespace}/{model_name}-GGUF
llama-server -hf {namespace}/{model_name}-GGUF --embeddings
```
Then the endpoint can be accessed at http://localhost:8080/embedding, for

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,6 @@
#!/usr/bin/env bash
COLLECTION_SLUG=$(python ./create_collection.py --return-slug)
echo "Created collection: $COLLECTION_SLUG"

View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
curl --request POST \
--url http://localhost:8080/embedding \
--header "Content-Type: application/json" \
--data '{"input": "Hello world today"}' \
--silent

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"

View File

@ -40,7 +40,7 @@ if os.path.exists(index_path):
file_path = os.path.join(model_path, file_name)
print(f"\n--- From {file_name} ---")
with safe_open(file_path, framework="pt") as f:
with safe_open(file_path, framework="pt") as f: # type: ignore
for tensor_name in sorted(tensor_names):
tensor = f.get_tensor(tensor_name)
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
@ -49,7 +49,7 @@ elif os.path.exists(single_file_path):
# Single file model (original behavior)
print("Single-file model detected")
with safe_open(single_file_path, framework="pt") as f:
with safe_open(single_file_path, framework="pt") as f: # type: ignore
keys = f.keys()
print("Tensors in model:")
for key in sorted(keys):

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
set -e
#

View File

@ -129,10 +129,11 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF)
option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -101,7 +101,6 @@ extern "C" {
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_vxe (void);
GGML_BACKEND_API int ggml_cpu_has_nnpa (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void);
@ -135,6 +134,7 @@ extern "C" {
GGML_BACKEND_API ggml_backend_reg_t ggml_backend_cpu_reg(void);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp32(const float *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_i32 (const float *, int32_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_fp16(const float *, ggml_fp16_t *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp16_to_fp32(const ggml_fp16_t *, float *, int64_t);
GGML_BACKEND_API void ggml_cpu_fp32_to_bf16(const float *, ggml_bf16_t *, int64_t);

View File

@ -511,6 +511,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_IM2COL_3D,
GGML_OP_CONV_2D,
GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW,
@ -1403,6 +1404,7 @@ extern "C" {
struct ggml_tensor * a,
struct ggml_tensor * b);
// note: casting from f32 to i32 will discard the fractional part
GGML_API struct ggml_tensor * ggml_cast(
struct ggml_context * ctx,
struct ggml_tensor * a,
@ -1527,7 +1529,11 @@ extern "C" {
struct ggml_context * ctx,
struct ggml_tensor * a);
// supports 3D: a->ne[2] == b->ne[1]
// supports 4D a:
// a [n_embd, ne1, ne2, ne3]
// b I32 [n_rows, ne2, ne3, 1]
//
// return [n_embd, n_rows, ne2, ne3]
GGML_API struct ggml_tensor * ggml_get_rows(
struct ggml_context * ctx,
struct ggml_tensor * a, // data
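As an aside, a minimal sketch (illustrative, not part of this commit) of the batched 4D case described in the comment above, assuming the usual ggml context setup and made-up sizes:

```cpp
#include "ggml.h"

// illustrative sizes: embedding width, rows per matrix, two batch dimensions, rows to gather
const int64_t n_embd = 8, ne1 = 32, ne2 = 4, ne3 = 2, n_rows = 5;

struct ggml_init_params ip = { /*.mem_size   =*/ 16*1024*1024,
                               /*.mem_buffer =*/ NULL,
                               /*.no_alloc   =*/ false };
struct ggml_context * ctx = ggml_init(ip);

struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, n_embd, ne1, ne2, ne3); // data
struct ggml_tensor * b = ggml_new_tensor_3d(ctx, GGML_TYPE_I32, n_rows, ne2, ne3);      // row ids per (ne2, ne3)

struct ggml_tensor * rows = ggml_get_rows(ctx, a, b); // -> [n_embd, n_rows, ne2, ne3]
```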
@ -1870,6 +1876,41 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_im2col_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2, // dilation depth
enum ggml_type dst_type);
// a: [OC*IC, KD, KH, KW]
// b: [N*IC, ID, IH, IW]
// result: [N*OC, OD, OH, OW]
GGML_API struct ggml_tensor * ggml_conv_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2 // dilation depth
);
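A hedged sketch (again illustrative, not from this commit) of how the new ggml_conv_3d entry point might be called, following the shape comment above and ggml's innermost-first ne ordering:

```cpp
// kernel a: ne = (KW, KH, KD, IC*OC), input b: ne = (IW, IH, ID, IC*N), per the comment above
// assumes a ggml_context * ctx as in the earlier ggml_get_rows sketch
const int64_t KW = 3, KH = 3, KD = 3, IC = 4, OC = 8, IW = 16, IH = 16, ID = 16, N = 1;

struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, KW, KH, KD, IC*OC);
struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, IW, IH, ID, IC*N);

// unit stride, no padding, dilation 1 in all three spatial dimensions
struct ggml_tensor * out = ggml_conv_3d(ctx, a, b, IC,
                                        /*s0*/ 1, /*s1*/ 1, /*s2*/ 1,
                                        /*p0*/ 0, /*p1*/ 0, /*p2*/ 0,
                                        /*d0*/ 1, /*d1*/ 1, /*d2*/ 1); // -> [OW, OH, OD, OC*N]
```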
// kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size
// padding is zero
@ -1941,7 +1982,7 @@ extern "C" {
int d0, // dilation dimension 0
int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_3d(
GGML_API struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
struct ggml_tensor * b, // input [W, H, D, C * N]
@ -2048,6 +2089,19 @@ extern "C" {
int p2,
int p3);
GGML_API struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
);
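And a similar sketch for the new asymmetric padding API; the lp*/rp* arguments are assumed here to be the leading ("front") and trailing ("behind") pad sizes for each dimension, which matches how the CANN backend reads the op params later in this commit:

```cpp
// assumes a ggml_context * ctx and a 4D F32 tensor x
// pad dim0 with 1 element in front and 2 behind, dim1 with 3 behind; leave dims 2 and 3 unchanged
struct ggml_tensor * padded = ggml_pad_ext(ctx, x,
                                           /*lp0*/ 1, /*rp0*/ 2,
                                           /*lp1*/ 0, /*rp1*/ 3,
                                           /*lp2*/ 0, /*rp2*/ 0,
                                           /*lp3*/ 0, /*rp3*/ 0);
```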
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx,

View File

@ -114,6 +114,9 @@ extern "C" {
void (*event_record)(ggml_backend_t backend, ggml_backend_event_t event);
// wait for an event on a different stream
void (*event_wait) (ggml_backend_t backend, ggml_backend_event_t event);
// (optional) sort/optimize the nodes in the graph
void (*optimize_graph) (ggml_backend_t backend, struct ggml_cgraph * cgraph);
};
struct ggml_backend {

View File

@ -463,6 +463,13 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
backend->iface.event_wait(backend, event);
}
static void ggml_backend_optimize_graph(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
if (backend->iface.optimize_graph != NULL) {
backend->iface.optimize_graph(backend, cgraph);
}
}
// Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
@ -651,7 +658,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC
#define GGML_SCHED_MAX_SPLIT_INPUTS 30
#endif
#ifndef GGML_SCHED_MAX_COPIES
@ -1298,6 +1305,10 @@ void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgra
struct ggml_backend_sched_split * split = &sched->splits[i];
split->graph = ggml_graph_view(graph, split->i_start, split->i_end);
// Optimize this split of the graph. This needs to happen before we make graph_copy,
// so they are in sync.
ggml_backend_optimize_graph(sched->backends[split->backend_id], &split->graph);
// add inputs to the graph copy so that they are allocated by ggml-alloc at the start of the split
for (int j = 0; j < split->n_inputs; j++) {
assert(graph_copy->size > (graph_copy->n_nodes + 1));

View File

@ -270,6 +270,7 @@ static struct ggml_backend_i blas_backend_i = {
/* .graph_compute = */ ggml_backend_blas_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_blas_guid(void) {

View File

@ -70,6 +70,8 @@
#include <aclnnop/aclnn_zero.h>
#include <aclnnop/aclnn_index_copy.h>
#include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_threshold.h>
#include <float.h>
#include <cmath>
@ -587,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// the position of elements in the array indicates which direction to pad;
// each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
// dim2.front, dim2.behind, dim3.front, dim3.behind]
int64_t paddings[] = {
0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1],
0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]};
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
aclnn_pad(ctx, acl_src, acl_dst, paddings);
ggml_cann_release_resources(ctx, acl_src, acl_dst);
}
@ -973,18 +982,19 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
);
// build rstd, zero...
size_t acl_rstd_nb[GGML_MAX_DIMS];
int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1];
for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
}
aclTensor* acl_rstd = get_f32_cache_acl_tensor(
ctx,
&ctx.rms_norm_zero_tensor_cache.cache,
ctx.rms_norm_zero_tensor_cache.size,
src->ne,
acl_rstd_ne,
acl_rstd_nb,
GGML_MAX_DIMS,
GGML_MAX_DIMS - 1,
0.0f // value
);
@ -1423,21 +1433,25 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
* @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment.
* @param dtype Data type for slope tensor.
*/
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){
int64_t ne[] = {size};
size_t nb[] = {sizeof(uint16_t)};
float m, int64_t size, float start, float stop, float step, ggml_type dtype){
aclDataType acl_type = ggml_cann_type_mapping(dtype);
size_t type_size = ggml_type_size(dtype);
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t));
int64_t ne[] = {size};
size_t nb[] = {type_size};
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size);
void* arange_buffer = arange_allocator.get();
aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
arange_buffer, acl_type, type_size, ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size);
aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1);
slope_buffer, acl_type, type_size, ne, nb, 1);
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
@ -1468,10 +1482,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu
* @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation.
* @param dtype Data type for slope tensor.
*
*/
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) {
void* slope_buffer, float max_bias, ggml_type dtype) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
float m0 = powf(2.0f, -(max_bias) / n_head_log2);
@ -1488,7 +1503,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
float step = 1;
float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step);
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);
if (n_head_log2 < n_head) {
// arange2
start = 2 * (n_head_log2 - n_head_log2) + 1;
@ -1497,7 +1512,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
count = n_head - n_head_log2;
aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step);
m1, count, start, end + 1, step, dtype);
}
}
@ -1534,7 +1549,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias);
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);
}
// broadcast for mask, slope and dst;
@ -1760,10 +1775,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
case GGML_TYPE_F16: {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
ctx.pool(), ggml_nelements(src0) * sizeof(float));
void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(float_t);
src_trans_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
}
@ -1807,14 +1822,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne;
dequant_nb[0] = sizeof(float_t);
dequant_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
}
scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float_t));
ctx.pool(), ggml_nelements(src0) * sizeof(float));
aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
@ -1823,11 +1838,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
aclTensor* dequant_tensor = ggml_cann_create_tensor(
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t),
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
dequant_nb[0] = sizeof(float_t);
dequant_nb[0] = sizeof(float);
dequant_ne = src0->ne;
for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
@ -1948,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
aclTensor* acl_weight_tensor;
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) {
int64_t acl_stride[2] = {1, transpose_ne[1]};
@ -2263,6 +2278,7 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
*/
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float* corr_dims, float ext_factor,
float theta_scale, float freq_scale,
float attn_factor, bool is_neox) {
// init sin/cos cache; the cache has a different repeat method depending on
@ -2274,8 +2290,8 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t),
theta_scale_length * sizeof(float_t)};
size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
theta_scale_length * sizeof(float)};
GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0];
@ -2285,7 +2301,7 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
size_t theta_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float_t);
theta_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
}
@ -2306,10 +2322,10 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0;
@ -2318,33 +2334,77 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
aclTensor* acl_yarn_ramp_tensor = nullptr;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix); since the rope_yarn_ramp computed above is negated,
// this is equivalent to caching freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
}
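// Reference for the block above: a host-side sketch of the cached per-element factor
// (illustrative helper, not part of this change):
//
//   float yarn_cached_factor(float i0_half, float low, float high,
//                            float ext_factor, float freq_scale) {
//       float y    = (i0_half - low) / MAX(0.001f, high - low);
//       float ramp = (MIN(1.0f, MAX(0.0f, y)) - 1.0f) * ext_factor;  // negated ramp_mix
//       return freq_scale + (freq_scale - 1.0f) * ramp;              // == freq_scale - freq_scale * ramp_mix + ramp_mix
//   }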
// power
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
acl_theta_scale_tensor);
// freq_scale
if (freq_scale != 1) {
if (ext_factor != 0) {
aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
}
ggml_cann_release_resources(ctx, acl_theta_scale);
ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
} else {
// use cache
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float_t));
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void* freq_fac_res_ptr = freq_fac_res_allocator.get();
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
freq_fac_res_ptr, ACL_FLOAT, sizeof(float_t),
freq_fac_res_ptr, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
@ -2359,32 +2419,36 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
// power * position
int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t));
theta_length * sizeof(float));
void* theta_buffer = theta_allocator.get();
aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float),
theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor);
// sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(),
theta_length * sizeof(float_t));
theta_length * sizeof(float));
void* sin_buffer = sin_allocator.get();
aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
ggml_cann_pool_alloc cos_allocator(ctx.pool(),
theta_length * sizeof(float_t));
theta_length * sizeof(float));
void* cos_buffer = cos_allocator.get();
aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb,
cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
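// YaRN magnitude correction: with extrapolation enabled, attn_factor is additionally
// scaled by 1 + 0.1 * ln(1 / freq_scale), mirroring the CPU rope_yarn reference.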
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// attn_factor
if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
@ -2393,15 +2457,15 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat
@ -2465,8 +2529,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0);
GGML_ASSERT(n_dims % 2 == 0);
// TODO: ext_factor != 0
GGML_ASSERT(ext_factor == 0);
const float theta_scale = powf(freq_base, -2.0f / n_dims);
@ -2484,20 +2546,20 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void *cos_tensor_buffer = cos_tensor_allocator.get();
// init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer,
aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t);
sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
}
aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float_t),
ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_src = ggml_cann_create_tensor(src0);
@ -2512,7 +2574,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void* minus_one_scale_buffer = nullptr;
ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
ggml_cann_pool_alloc minus_one_scale_allocator(
ctx.pool(), sizeof(float_t) * src0->ne[0]);
ctx.pool(), sizeof(float) * src0->ne[0]);
if (!is_neox) {
// roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
input_roll_buffer = roll_allocator.get();
@ -2542,13 +2604,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS];
minus_one_nb[0] = sizeof(float_t);
minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
}
acl_minus_one_tensor = aclnn_values(
ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
int64_t dim = 3;
int64_t* index = new int64_t[src0->ne[0]];
for (int i = 0; i < src0->ne[0]; i++) {
@ -2576,22 +2638,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
minus_one_scale_buffer = minus_one_scale_allocator.get();
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS];
minus_one_nb[0] = sizeof(float_t);
minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
}
acl_minus_one_tensor = aclnn_values(
ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1);
ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
// -1 * first half
int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
size_t first_half_nb[GGML_MAX_DIMS];
first_half_nb[0] = sizeof(float_t);
first_half_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
}
aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne,
minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne,
first_half_nb, GGML_MAX_DIMS);
bool inplace = true;
float scale = -1;
@ -2631,28 +2693,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: ne0 != n_dims in mode2
} else if (src0->type == GGML_TYPE_F16) {
size_t input_fp32_nb[GGML_MAX_DIMS];
input_fp32_nb[0] = sizeof(float_t);
input_fp32_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
}
ggml_cann_pool_alloc fp32_allocator1(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer1 = fp32_allocator1.get();
aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne,
input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator2(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer2 = fp32_allocator2.get();
aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne,
input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t));
ctx.pool(), ggml_nelements(dst) * sizeof(float));
output_fp32_buffer = fp32_allocator.get();
aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne,
output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
@ -2749,8 +2811,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
int64_t dilationVal[] = {1};
aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
bool transposed = true;
int64_t groups = 1;
int8_t cubeMathType = 0;
#ifdef ASCEND_310P
@ -2758,7 +2818,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
#endif
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
padding, dilation, transposed, padding, groups, acl_dst, cubeMathType);
padding, dilation, true, padding, 1, acl_dst, cubeMathType);
ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
}
@ -3220,7 +3280,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias);
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);
int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS];

View File

@ -395,6 +395,7 @@ struct ggml_backend_cann_context {
#ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph;
bool acl_graph_mode = true;
#endif
cann_task_queue task_queue;
bool async_mode;
@ -404,7 +405,6 @@ struct ggml_backend_cann_context {
ggml_cann_tensor_cache rms_norm_one_tensor_cache;
ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
/**
@ -419,6 +419,13 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
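// GGML_CANN_ACL_GRAPH (default "on") selects whole-graph execution through ACL graphs;
// a falsy value falls back to eager, op-by-op submission.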
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
__func__, device,
acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
}
/**

View File

@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
return GGML_STATUS_SUCCESS;
}
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed
namespace {
void* g_nz_workspace = nullptr;
size_t g_nz_workspace_allocated = 0;
/**
* @brief Workspace for caching NZ buffers per device.
*
* This struct manages a device buffer used in NZ computations. It supports
* allocation, reallocation, and clearing of cached memory. The struct is
* designed to be used with a global array, one per device.
*/
struct ggml_cann_nz_workspace {
void* ptr; // Pointer to allocated device buffer
size_t allocated; // Size of currently allocated buffer in bytes
void release_nz_workspace() {
if (g_nz_workspace) {
aclrtFree(g_nz_workspace);
g_nz_workspace = nullptr;
g_nz_workspace_allocated = 0;
/**
* @brief Constructor. Initializes the workspace with no allocated memory.
*/
ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
/**
* @brief Free cached memory and reset the workspace.
*
* If a buffer has been allocated, this function releases it using
* aclrtFree and resets internal state.
*/
void clear() {
if (ptr) {
ACL_CHECK(aclrtFree(ptr));
ptr = nullptr;
allocated = 0;
}
}
void relloc_nz_workspace(size_t new_size) {
if (new_size > g_nz_workspace_allocated) {
if (g_nz_workspace) {
aclrtFree(g_nz_workspace);
g_nz_workspace = nullptr;
}
ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
g_nz_workspace_allocated = new_size;
/**
* @brief Allocate or reallocate the workspace buffer.
*
* If the requested size is larger than the currently allocated size,
* the old buffer will be freed and a new buffer of the requested size
* will be allocated on the device.
*
* @param new_size Size in bytes to allocate for the workspace.
*/
void realloc(size_t new_size) {
if (new_size > allocated) {
clear();
ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
allocated = new_size;
}
}
}
/**
* @brief Get the device buffer pointer.
*
* @return Pointer to the allocated buffer, or nullptr if not allocated.
*/
void* get() const { return ptr; }
};
/**
* @brief Global array of NZ workspaces, one per device.
*/
static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
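A minimal usage sketch of this per-device cache (illustrative only; device, workspaceSize and executor come from the surrounding CANN call, as in weight_format_to_nz below):

    // grow the cached device buffer only when the requested size increases
    g_nz_workspaces[device].realloc(workspaceSize);
    void* ws = g_nz_workspaces[device].get();
    ACL_CHECK(aclnnTransMatmulWeight(ws, workspaceSize, executor, nullptr));
    // once the whole graph has been computed, the cached memory can be released
    g_nz_workspaces[device].clear();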
/**
* @brief Convert tensor weights to NZ format using Ascend CANN API.
@ -1149,13 +1184,13 @@ namespace {
* improve performance on certain hardware.
*
* @param tensor Pointer to the input ggml_tensor containing the weights.
* @param data Pointer to the raw data buffer for the tensor weights.
* @param offset Byte offset within the tensor data buffer where weights start.
* @param device Device ID used to select the per-device NZ workspace.
*
* @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation.
*/
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0;
@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
&workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace.
relloc_nz_workspace(workspaceSize);
g_nz_workspaces[device].realloc(workspaceSize);
void* g_nz_workspace = g_nz_workspaces[device].get();
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed));
@ -1196,14 +1233,14 @@ static void ggml_backend_cann_buffer_set_tensor(
// Why aclrtSynchronizeDevice?
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset);
weight_format_to_nz(tensor, offset, ctx->device);
}
} else {
void *transform_buffer = malloc(size);
@ -1279,6 +1316,10 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
ACL_MEMCPY_DEVICE_TO_DEVICE));
return true;
} else {
#ifdef ASCEND_310P
// TODO: Support 310p P2P copy
return false;
#endif
// Different device but can access by peer.
int32_t canAccessPeer = 0;
ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
@ -1439,7 +1480,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
int64_t ne0 = tensor->ne[0];
// Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or(""));
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
// the last line must be bigger than 32, because every single op deals
// with at least 32 bytes.
@ -2000,6 +2041,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
ggml_backend_is_cann(backend_dst));
GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src));
if (!ggml_backend_buffer_is_cann(src->buffer) ||
!ggml_backend_buffer_is_cann(dst->buffer)) {
return false;
@ -2020,6 +2063,10 @@ static bool ggml_backend_cann_cpy_tensor_async(
return true;
}
if (backend_src != backend_dst) {
#ifdef ASCEND_310P
// TODO: Support 310p P2P copy
return false;
#endif
ggml_backend_cann_buffer_context* buf_ctx_src =
(ggml_backend_cann_buffer_context*)buf_src->context;
ggml_backend_cann_buffer_context* buf_ctx_dst =
@ -2036,7 +2083,6 @@ static bool ggml_backend_cann_cpy_tensor_async(
}
// need open both directions for memcpyasync between devices.
ggml_cann_set_device(cann_ctx_dst->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
ggml_cann_set_device(cann_ctx_src->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
@ -2046,9 +2092,17 @@ static bool ggml_backend_cann_cpy_tensor_async(
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_src->stream()));
// record event on src stream after the copy
// TODO: this event is not effective with acl graph mode, change to use aclrtSynchronizeStream
// if (!cann_ctx_src->copy_event) {
// ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
// }
// ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
// TODO: workaround, since events do not work reliably here.
aclrtSynchronizeStream(cann_ctx_src->stream());
// // wait on dst stream for the copy to complete
// ggml_cann_set_device(cann_ctx_dst->device);
// ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
ACL_CHECK(aclrtSynchronizeStream(cann_ctx_src->stream()));
} else {
// src and dst are on the same backend
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
@ -2246,12 +2300,16 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device);
release_nz_workspace();
g_nz_workspaces[cann_ctx->device].clear();
#ifdef USE_ACL_GRAPH
bool use_cann_graph = true;
bool cann_graph_update_required = false;
if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false;
}
if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph());
@ -2401,16 +2459,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
}
case GGML_OP_ROPE: {
// TODO: with ops-test v == 1
float ext_factor = 0.0f;
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) {
return false;
}
// TODO: ext_factor != 0
if (ext_factor != 0) {
return false;
}
const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) {
@ -2419,10 +2471,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (mode & GGML_ROPE_TYPE_VISION) {
return false;
}
#ifdef ASCEND_310P
if (!ggml_is_contiguous(op->src[0])) {
return false;
}
#endif
return true;
}
case GGML_OP_UPSCALE: {
@ -2484,12 +2537,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_ARGMAX:
case GGML_OP_COS:
case GGML_OP_SIN:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_LOG:
case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL:
return true;
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: CANN requires ((weightL - 1) * dilationW - padLeft) to be no larger than 255 (values like 1336 fail).
return (op->src[0]->ne[0] - 1) <= 255;
case GGML_OP_SCALE:
float bias;
memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
@ -2523,13 +2578,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// different head sizes of K and V are not supported yet
return false;
}
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support
return false;
@ -2642,6 +2690,7 @@ static const ggml_backend_i ggml_backend_cann_interface = {
/* .graph_compute = */ ggml_backend_cann_graph_compute,
/* .event_record = */ ggml_backend_cann_event_record,
/* .event_wait = */ ggml_backend_cann_event_wait,
/* .optimize_graph = */ NULL,
};
/**

View File

@ -433,15 +433,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/arch/riscv/quants.c
ggml-cpu/arch/riscv/repack.cpp
)
if (GGML_RVV)
set(MARCH_STR "rv64gc")
if (GGML_RV_ZFH)
string(APPEND MARCH_STR "_zfh")
endif()
if (GGML_XTHEADVECTOR)
list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d)
elseif (GGML_RV_ZFH)
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d)
else()
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d)
string(APPEND MARCH_STR "_xtheadvector")
elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif()
endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
@ -450,7 +457,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562")
set(GGML_NNPA OFF)
message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931")
@ -472,11 +478,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE)
endif()
if (GGML_NNPA)
message(STATUS "NNPA enabled")
list(APPEND ARCH_DEFINITIONS GGML_NNPA)
endif()
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)

View File

@ -1270,29 +1270,40 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
int tmp, tmp2, sumi;
float ftmp, ft2;
const uint8_t * restrict q40;
const uint8_t * restrict q41;
const uint8_t * restrict q42;
const uint8_t * restrict q43;
const int8_t * restrict q80;
const int8_t * restrict q81;
const int8_t * restrict q82;
const int8_t * restrict q83;
int s0, s1, s2, s3;
__asm__ __volatile__(
"vsetivli zero, 12, e8, m1\n\t"
"vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]}
"vsetivli zero, 4, e32, m1\n\t"
"li %[s1], 8\n\t"
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vle32.v v1, (%[s6b])\n\t"
"vslide1down.vx v1, v1, zero\n\t"
"vmv.v.x v16, zero\n\t"
"vslidedown.vi v2, v1, 2\n\t"
"vmv1r.v v3, v2\n\t"
"vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
"vsetivli zero, 2, e32, m1\n\t"
"vsetivli zero, 2, e32, m1, ta, ma\n\t"
"vmv.v.i v4, 4\n\t"
"vand.vx v8, v1, %[kmask1]\n\t"
"vslide1up.vx v5, v4, zero\n\t" // {0, 4}
"vsrl.vi v6, v1, 6\n\t"
"vsrl.vv v7, v2, v5\n\t"
"vsse32.v v8, (%[utmp]), %[s1]\n\t"
"vand.vx v0, v6, %[kmask3]\n\t"
"vand.vx v2, v7, %[kmask2]\n\t"
"vsll.vi v6, v0, 4\n\t"
"li %[t2], 8\n\t"
"addi %[t1], %[utmp], 4\n\t"
"addi %[s0], %[utmp], 4\n\t"
"vor.vv v1, v6, v2\n\t"
"vsse32.v v8, (%[utmp]), %[t2]\n\t"
"vsse32.v v1, (%[t1]), %[t2]\n\t"
"vsetivli zero, 8, e16, m1\n\t"
"vsse32.v v1, (%[s0]), %[s1]\n\t"
"vsetivli zero, 8, e16, m1, ta, ma\n\t"
"vle32.v v2, (%[bsums])\n\t"
"vnsrl.wi v0, v2, 0\n\t"
"vnsrl.wi v1, v2, 16\n\t"
@ -1300,13 +1311,131 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vle8.v v3, (%[mins])\n\t"
"vzext.vf2 v4, v3\n\t"
"vwmul.vv v6, v4, v2\n\t"
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vredsum.vs v0, v6, v16\n\t"
"vredsum.vs v0, v7, v0\n\t"
"vfcvt.f.x.v v0, v0\n\t"
"vfmv.f.s %[ftmp], v0\n\t"
"vsetivli zero, 16, e8, m1, ta, ma\n\t"
"vle8.v v0, (%[xs])\n\t"
"fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
"addi %[q40], %[xs], 64\n\t"
"addi %[q41], %[xs], 16\n\t"
"addi %[q42], %[xs], 32\n\t"
"addi %[q43], %[xs], 48\n\t"
"addi %[q80], %[ys], 64\n\t"
"vle8.v v1, (%[q41])\n\t"
"vle8.v v2, (%[q42])\n\t"
"addi %[q81], %[ys], 16\n\t"
"addi %[q41], %[q41], 64\n\t"
"addi %[q82], %[ys], 32\n\t"
"vle8.v v3, (%[q43])\n\t"
"vle8.v v8, (%[ys])\n\t"
"addi %[q42], %[q42], 64\n\t"
"addi %[q83], %[ys], 48\n\t"
"addi %[q43], %[q43], 64\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vle8.v v9, (%[q81])\n\t"
"vle8.v v10, (%[q82])\n\t"
"vand.vi v0, v0, 0xF\n\t"
"addi %[q81], %[q81], 64\n\t"
"vsrl.vi v5, v1, 4\n\t"
"addi %[q82], %[q82], 64\n\t"
"vle8.v v11, (%[q83])\n\t"
"vle8.v v12, (%[q80])\n\t"
"vand.vi v1, v1, 0xF\n\t"
"addi %[q83], %[q83], 64\n\t"
"vsrl.vi v6, v2, 4\n\t"
"addi %[q80], %[q80], 64\n\t"
"vle8.v v13, (%[q81])\n\t"
"vle8.v v14, (%[q82])\n\t"
"vand.vi v2, v2, 0xF\n\t"
"addi %[q81], %[q81], 64\n\t"
"vsrl.vi v7, v3, 4\n\t"
"addi %[q82], %[q82], 64\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vle8.v v15, (%[q83])\n\t"
"vle8.v v0, (%[q40])\n\t"
"vand.vi v3, v3, 0xF\n\t"
"addi %[q83], %[q83], 64\n\t"
"vwmul.vv v24, v2, v12\n\t"
"vwmul.vv v20, v4, v10\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmacc.vv v16, v1, v9\n\t"
"vle8.v v1, (%[q41])\n\t"
"vle8.v v2, (%[q42])\n\t"
"vwmacc.vv v24, v3, v13\n\t"
"vwmacc.vv v20, v5, v11\n\t"
"vwmacc.vv v28, v7, v15\n\t"
"addi %[q40], %[q80], 64\n\t"
"addi %[q41], %[q81], 64\n\t"
"vle8.v v3, (%[q43])\n\t"
"vle8.v v8, (%[q80])\n\t"
"addi %[q42], %[q82], 64\n\t"
"addi %[q43], %[q83], 64\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vle8.v v9, (%[q81])\n\t"
"vle8.v v10, (%[q82])\n\t"
"vand.vi v0, v0, 0xF\n\t"
"vsrl.vi v5, v1, 4\n\t"
"vsrl.vi v7, v3, 4\n\t"
"vand.vi v3, v3, 0xF\n\t"
"vle8.v v11, (%[q83])\n\t"
"vle8.v v12, (%[q40])\n\t"
"vand.vi v1, v1, 0xF\n\t"
"vsrl.vi v6, v2, 4\n\t"
"vand.vi v2, v2, 0xF\n\t"
"vwmul.vv v18, v0, v8\n\t"
"vle8.v v13, (%[q41])\n\t"
"vle8.v v14, (%[q42])\n\t"
"vwmul.vv v26, v2, v12\n\t"
"vwmul.vv v22, v4, v10\n\t"
"vwmul.vv v30, v6, v14\n\t"
"vwmacc.vv v18, v1, v9\n\t"
"vle8.v v15, (%[q43])\n\t"
"vwmacc.vv v26, v3, v13\n\t"
"vwmacc.vv v22, v5, v11\n\t"
"vwmacc.vv v30, v7, v15\n\t"
"vmv.v.x v0, zero\n\t"
"vsetivli zero, 8, e32, m2\n\t"
"vredsum.vs v0, v6, v0\n\t"
"vmv.x.s %[sumi], v0"
: [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi)
: [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
, [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1)
"vsetivli zero, 16, e16, m2, ta, ma\n\t"
"vwredsum.vs v4, v16, v0\n\t"
"lbu %[s0], 0(%[scale])\n\t"
"vwredsum.vs v5, v20, v0\n\t"
"lbu %[s1], 1(%[scale])\n\t"
"vwredsum.vs v6, v24, v0\n\t"
"lbu %[s2], 2(%[scale])\n\t"
"vwredsum.vs v7, v28, v0\n\t"
"lbu %[s3], 3(%[scale])\n\t"
"vwredsum.vs v8, v18, v0\n\t"
"lbu %[q40], 4(%[scale])\n\t"
"vwredsum.vs v9, v22, v0\n\t"
"lbu %[q41], 5(%[scale])\n\t"
"vwredsum.vs v10, v26, v0\n\t"
"lbu %[q42], 6(%[scale])\n\t"
"vwredsum.vs v11, v30, v0\n\t"
"lbu %[q43], 7(%[scale])\n\t"
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vmul.vx v0, v4, %[s0]\n\t"
"vmul.vx v1, v8, %[q40]\n\t"
"vmacc.vx v0, %[s1], v5\n\t"
"vmacc.vx v1, %[q41], v9\n\t"
"vmacc.vx v0, %[s2], v6\n\t"
"vmacc.vx v1, %[q42], v10\n\t"
"vmacc.vx v0, %[s3], v7\n\t"
"vmacc.vx v1, %[q43], v11\n\t"
"vfcvt.f.x.v v0, v0\n\t"
"vfcvt.f.x.v v1, v1\n\t"
"vfmv.f.s %[ft2], v0\n\t"
"vfmv.f.s %[ftmp], v1\n\t"
"fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
"fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
: [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
, [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
, [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
, [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
: [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
, [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
, [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
, [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
@ -1314,59 +1443,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
sumf -= dmin * sumi;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
sumi = 0;
const uint8_t * scale = scales;
for (int j = 0; j < QK_K/128; ++j) {
int vl128 = 128, vl64 = 64, vl32 = 32;
__asm__ __volatile__(
"vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v8, (%[q8])\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
"vle8.v v0, (%[q4])\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vand.vi v0, v0, 0xF\n\t"
"vsetvli zero, %[vl32], e8, m2\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmul.vv v20, v4, v10\n\t"
"vwmul.vv v24, v2, v12\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vle8.v v2, (%[scale])\n\t"
"vmv.v.x v0, zero\n\t"
"vzext.vf4 v1, v2\n\t"
"vsetvli zero, %[vl32], e16, m4\n\t"
"vwredsum.vs v6, v24, v0\n\t"
"vwredsum.vs v7, v28, v0\n\t"
"vwredsum.vs v4, v16, v0\n\t"
"vwredsum.vs v5, v20, v0\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vslideup.vi v6, v7, 1\n\t"
"vslideup.vi v4, v5, 1\n\t"
"vslideup.vi v4, v6, 2\n\t"
"vmul.vv v8, v4, v1\n\t"
"vredsum.vs v0, v8, v0\n\t"
"vmv.x.s %[tmp], v0\n\t"
"add %[sumi], %[sumi], %[tmp]"
: [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
: [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
, [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
q4 += 64; q8 += 128; scale += 4;
}
sumf += d * sumi;
}
break;
default:
@ -1693,6 +1769,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
case 128:
for (int i = 0; i < nb; ++i) {
__builtin_prefetch(&x[i + 1].d, 0, 1);
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q6 = x[i].ql;
@ -1701,23 +1779,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int8_t * restrict scale = x[i].scales;
int sum_t = 0;
int t0;
int q6h;
float ftmp;
for (int j = 0; j < QK_K/128; ++j) {
__asm__ __volatile__(
"addi %[q6h], %[q6], 32\n\t"
"ld t0, 0(%[scale])\n\t"
"addi %[scale], %[scale], 8\n\t"
"slli t6, t0, 1 * 8\n\t"
"lb zero, 0(%[q6])\n\t"
"slli t5, t0, 2 * 8\n\t"
"slli t4, t0, 3 * 8\n\t"
"lb zero, 0(%[q6h])\n\t"
"slli t3, t0, 4 * 8\n\t"
"slli t2, t0, 5 * 8\n\t"
"lb zero, 0(%[qh])\n\t"
"lb zero, 31(%[q6h])\n\t"
"slli t1, t0, 6 * 8\n\t"
"srai a7, t0, 56\n\t"
"vsetvli zero, %[vl32], e8, m2\n\t"
"vle8.v v8, (%[q6])\n\t"
"srai t6, t6, 56\n\t"
"srai t5, t5, 56\n\t"
"srai t4, t4, 56\n\t"
"srai t3, t3, 56\n\t"
"vle8.v v10, (%[q6h])\n\t"
"addi %[q6], %[q6], 64\n\t"
"slli t0, t0, 7 * 8\n\t"
"srai t2, t2, 56\n\t"
"srai t1, t1, 56\n\t"
"srai t0, t0, 56\n\t"
"vle8.v v4, (%[qh])\n\t"
"vsrl.vi v12, v8, 4\n\t"
"vsrl.vi v14, v10, 4\n\t"
"lb zero, 0(%[q8])\n\t"
"vand.vi v8, v8, 0xF\n\t"
"vand.vi v10, v10, 0xF\n\t"
"lb zero, 32(%[q8])\n\t"
"vsll.vi v0, v4, 4\n\t"
"vsll.vi v2, v4, 2\n\t"
"lb zero, 64(%[q8])\n\t"
"vsrl.vi v6, v4, 2\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
"vle8.v v8, (%[q6])\n\t"
"vsrl.vi v12, v8, 4\n\t"
"vand.vi v8, v8, 0xF\n\t"
"vsetvli zero, %[vl128], e8, m8\n\t"
"vand.vx v0, v0, %[mask]\n\t"
"lb zero, 96(%[q8])\n\t"
"vand.vx v2, v2, %[mask]\n\t"
"vand.vx v4, v4, %[mask]\n\t"
"vand.vx v6, v6, %[mask]\n\t"
"vor.vv v8, v8, v0\n\t"
"lb zero, 127(%[q8])\n\t"
"vor.vv v10, v10, v2\n\t"
"vor.vv v12, v12, v4\n\t"
"vor.vv v14, v14, v6\n\t"
"vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v0, (%[q8])\n\t"
"vsub.vx v8, v8, %[vl32]\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
@ -1734,34 +1848,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vwredsum.vs v13, v28, v0\n\t"
"vwredsum.vs v14, v30, v0\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vslideup.vi v10, v9, 1\n\t"
"vslideup.vi v8, v7, 1\n\t"
"vslideup.vi v11, v12, 1\n\t"
"vslideup.vi v13, v14, 1\n\t"
"vslideup.vi v10, v8, 2\n\t"
"vslideup.vi v11, v13, 2\n\t"
"vsetivli zero, 8, e32, m2\n\t"
"vle8.v v2, (%[scale])\n\t"
"vsext.vf4 v4, v2\n\t"
"vmul.vv v2, v4, v10\n\t"
"vredsum.vs v0, v2, v0\n\t"
"vmv.x.s %[t0], v0\n\t"
"add %[sumi], %[sumi], %[t0]"
: [sumi] "+&r" (sum_t), [t0] "=&r" (t0)
: [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale)
"vmul.vx v0, v10, t0\n\t"
"vmul.vx v1, v9, t1\n\t"
"vmacc.vx v0, t2, v8\n\t"
"vmacc.vx v1, t3, v7\n\t"
"vmacc.vx v0, t4, v11\n\t"
"vmacc.vx v1, t5, v12\n\t"
"vmacc.vx v0, t6, v13\n\t"
"vmacc.vx v1, a7, v14\n\t"
"vadd.vv v0, v0, v1\n\t"
"vfcvt.f.x.v v0, v0\n\t"
"vfmv.f.s %[ftmp], v0\n\t"
"fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
: [q6] "+&r" (q6), [q6h] "=&r" (q6h)
, [scale] "+&r" (scale)
, [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
: [qh] "r" (qh), [q8] "r" (q8)
, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
, [mask] "r" (0x30)
, [mask] "r" (0x30), [d] "f" (d)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
, "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
, "a6", "a5", "a4", "a3"
);
q6 += 64; qh += 32; q8 += 128; scale += 8;
qh += 32; q8 += 128;
}
sumf += d * sum_t;
}
break;
default:

View File

@ -53,9 +53,9 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
#if defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) {
__vector float srcv [8];
__vector float asrcv[8];
__vector float amaxv[8];
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
@ -74,8 +74,8 @@ void quantize_row_q8_0(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
y[i].d = GGML_CPU_FP32_TO_FP16(d);
for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v);
const float32x4_t v = vec_mul(srcv[j], vec_splats(id));
const int32x4_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1);
@ -98,9 +98,9 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
#if defined(__VXE__) || defined(__VXE2__)
for (int i = 0; i < nb; i++) {
__vector float srcv [8];
__vector float asrcv[8];
__vector float amaxv[8];
float32x4_t srcv [8];
float32x4_t asrcv[8];
float32x4_t amaxv[8];
for (int j = 0; j < 8; j++) srcv[j] = vec_xl(0, x + i*32 + 4*j);
for (int j = 0; j < 8; j++) asrcv[j] = vec_abs(srcv[j]);
@ -118,11 +118,11 @@ void quantize_row_q8_1(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, i
y[i].d = GGML_CPU_FP32_TO_FP16(d);
__vector int32_t acc = vec_splats(0);
int32x4_t acc = vec_splats(0);
for (int j = 0; j < 8; j++) {
const __vector float v = vec_mul(srcv[j], vec_splats(id));
const __vector int32_t vi = vec_signed(v);
const float32x4_t v = vec_mul(srcv[j], vec_splats(id));
const int32x4_t vi = vec_signed(v);
y[i].qs[4*j + 0] = vec_extract(vi, 0);
y[i].qs[4*j + 1] = vec_extract(vi, 1);
@ -162,37 +162,36 @@ void ggml_vec_dot_q4_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
float sumf = 0;
#if defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f);
float32x4_t acc = vec_splats(0.0f);
const __vector uint8_t v_m = vec_splats((const uint8_t)0x0F);
const __vector int8_t v_s = vec_splats( (const int8_t)0x08);
const uint8x16_t v_m = vec_splats((const uint8_t)0x0F);
const int8x16_t v_s = vec_splats( (const int8_t)0x08);
for (; ib < nb; ++ib) {
const __vector uint8_t v_x = vec_xl(0, x[ib].qs);
const __vector int8_t v_xl = (const __vector int8_t)(v_x & v_m);
const __vector int8_t v_xh = (const __vector int8_t)(v_x >> 4);
const uint8x16_t v_x = vec_xl(0, x[ib].qs);
const int8x16_t v_xl = (const int8x16_t)(v_x & v_m);
const int8x16_t v_xh = (const int8x16_t)(v_x >> 4);
const __vector int8_t v_xls = vec_sub(v_xl, v_s);
const __vector int8_t v_xhs = vec_sub(v_xh, v_s);
const int8x16_t v_xls = vec_sub(v_xl, v_s);
const int8x16_t v_xhs = vec_sub(v_xh, v_s);
const __vector int8_t v_yl = vec_xl(0 , y[ib].qs);
const __vector int8_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
const int8x16_t v_yl = vec_xl(0 , y[ib].qs);
const int8x16_t v_yh = vec_xl(QK8_0/2, y[ib].qs);
const __vector int16_t v_xylso = vec_mulo(v_xls, v_yl);
const __vector int16_t v_xylse = vec_mule(v_xls, v_yl);
const __vector int16_t v_xyhso = vec_mulo(v_xhs, v_yh);
const __vector int16_t v_xyhse = vec_mule(v_xhs, v_yh);
const int16x8_t v_xylso = vec_mulo(v_xls, v_yl);
const int16x8_t v_xylse = vec_mule(v_xls, v_yl);
const int16x8_t v_xyhso = vec_mulo(v_xhs, v_yh);
const int16x8_t v_xyhse = vec_mule(v_xhs, v_yh);
__vector int16_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
int16x8_t v_xy_ = v_xylso + v_xylse + v_xyhso + v_xyhse; v_xy_ += vec_reve(v_xy_);
const __vector float v_xy = vec_float(vec_unpackh(v_xy_));
const __vector float v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
const float32x4_t v_xy = vec_float(vec_unpackh(v_xy_));
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x[ib].d) * GGML_CPU_FP16_TO_FP32(y[ib].d));
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3];
sumf = vec_hsum_f32x4(acc);
*s = sumf;
#else
UNUSED(nb);
@ -249,8 +248,7 @@ void ggml_vec_dot_q4_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3] + summs;
sumf = vec_hsum_f32x4(acc) + summs;
*s = sumf;
#else
UNUSED(nb);
@ -351,7 +349,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
}
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1);
sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1);
#pragma GCC unroll 4
for (; ib < nb; ++ib) {
@ -390,7 +388,7 @@ void ggml_vec_dot_q5_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
const float32x4_t v_acc = vec_madd(v_xyf, v_d, vec_splats(0.0f));
sumf += vec_hsum(v_acc);
sumf += vec_hsum_f32x4(v_acc);
}
*s = sumf;
@ -502,7 +500,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_sum1 = vec_madd(v_xy1f, v_d1, v_sum1);
}
sumf += vec_hsum(v_sum0) + vec_hsum(v_sum1) + summs0 + summs1;
sumf += vec_hsum_f32x4(v_sum0) + vec_hsum_f32x4(v_sum1) + summs0 + summs1;
#pragma GCC unroll 4
for (; ib < nb; ++ib) {
@ -543,7 +541,7 @@ void ggml_vec_dot_q5_1_q8_1(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float32x4_t v_d = vec_splats(GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d));
const float32x4_t v_acc = vec_madd(v_xyf, v_d, v_acc);
sumf += vec_hsum(v_acc) + summs;
sumf += vec_hsum_f32x4(v_acc) + summs;
}
*s = sumf;
@ -575,7 +573,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
float sumf = 0;
#if defined(__VXE__) || defined(__VXE2__)
__vector float acc = vec_splats(0.0f);
float32x4_t acc = vec_splats(0.0f);
#pragma GCC unroll 8
for (; ib < nb; ++ib) {
@ -594,7 +592,7 @@ void ggml_vec_dot_q8_0_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const voi
acc = vec_madd(v_xy, v_d, acc);
}
sumf = acc[0] + acc[1] + acc[2] + acc[3];
sumf = vec_hsum_f32x4(acc);
*s = sumf;
#else
@ -718,10 +716,10 @@ void ggml_vec_dot_q3_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
isum2 = ggml_vec_dot(v_z, q3bytes[2], q8bytes[6]);
isum3 = ggml_vec_dot(v_z, q3bytes[3], q8bytes[7]);
isum += (isum0[0] + isum0[1] + isum0[2] + isum0[3]) * scale[0];
isum += (isum1[0] + isum1[1] + isum1[2] + isum1[3]) * scale[1];
isum += (isum2[0] + isum2[1] + isum2[2] + isum2[3]) * scale[2];
isum += (isum3[0] + isum3[1] + isum3[2] + isum3[3]) * scale[3];
isum += vec_hsum_i32x4(isum0) * scale[0];
isum += vec_hsum_i32x4(isum1) * scale[1];
isum += vec_hsum_i32x4(isum2) * scale[2];
isum += vec_hsum_i32x4(isum3) * scale[3];
scale += 4;
@ -819,7 +817,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_xl[1] = (int8x16_t)vec_and(v_x[1], v_lm);
const int32x4_t p1 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi1 += (p1[0] + p1[1] + p1[2] + p1[3]) * scales[2*j+0];
sumi1 += vec_hsum_i32x4(p1) * scales[2*j+0];
v_y[0] = vec_xl(0 , y0);
v_y[1] = vec_xl(16, y0);
@ -829,7 +827,7 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
v_xl[1] = (int8x16_t)vec_sr(v_x[1], 4);
const int32x4_t p2 = ggml_vec_dot(ggml_vec_dot(v_z, v_xl[0], v_y[0]), v_xl[1], v_y[1]);
sumi2 += (p2[0] + p2[1] + p2[2] + p2[3]) * scales[2*j+1];
sumi2 += vec_hsum_i32x4(p2) * scales[2*j+1];
}
sumf += d * (sumi1 + sumi2);
@ -911,7 +909,7 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int32x4_t v_minsho = vec_mulo(v_ysums, v_minsh);
const int32x4_t v_minshe = vec_mule(v_ysums, v_minsh);
const int32x4_t v_mins = vec_add(v_minsho, v_minshe);
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
const int32_t mins = vec_hsum_i32x4(v_mins);
const uint8_t * scales = (const uint8_t *)utmp;
const uint8_t * GGML_RESTRICT x0l = x[i].qs;
@ -948,8 +946,8 @@ void ggml_vec_dot_q5_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
int32x4_t sumi0 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[0], v_y[0]), q5b[1], v_y[1]);
int32x4_t sumi1 = ggml_vec_dot(ggml_vec_dot(v_z, q5b[2], v_y[2]), q5b[3], v_y[3]);
sumi += (sumi0[0] + sumi0[1] + sumi0[2] + sumi0[3]) * *scales++;
sumi += (sumi1[0] + sumi1[1] + sumi1[2] + sumi1[3]) * *scales++;
sumi += vec_hsum_i32x4(sumi0) * *scales++;
sumi += vec_hsum_i32x4(sumi1) * *scales++;
}
sumf += d * sumi - dmin * mins;
@ -1020,7 +1018,7 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int32x4_t v_minshe = vec_mule(v_ysumsh, v_scaleh);
const int32x4_t v_mins = v_minslo + v_minsle + v_minsho + v_minshe;
const int32_t mins = v_mins[0] + v_mins[1] + v_mins[2] + v_mins[3];
const int32_t mins = vec_hsum_i32x4(v_mins);
int32_t isum = 0;
for (int j = 0; j < QK_K/128; ++j) {
@ -1060,10 +1058,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
int32x4_t summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
int32x4_t summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
isum += vec_hsum_i32x4(summs0) * scale[0] +
vec_hsum_i32x4(summs1) * scale[1] +
vec_hsum_i32x4(summs2) * scale[2] +
vec_hsum_i32x4(summs3) * scale[3];
scale += 4;
@ -1094,10 +1092,10 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
summs2 = ggml_vec_dot(v_z, q6b[2], v_y[2]);
summs3 = ggml_vec_dot(v_z, q6b[3], v_y[3]);
isum += (summs0[0] + summs0[1] + summs0[2] + summs0[3]) * scale[0] +
(summs1[0] + summs1[1] + summs1[2] + summs1[3]) * scale[1] +
(summs2[0] + summs2[1] + summs2[2] + summs2[3]) * scale[2] +
(summs3[0] + summs3[1] + summs3[2] + summs3[3]) * scale[3];
isum += vec_hsum_i32x4(summs0) * scale[0] +
vec_hsum_i32x4(summs1) * scale[1] +
vec_hsum_i32x4(summs2) * scale[2] +
vec_hsum_i32x4(summs3) * scale[3];
scale += 4;
}
@ -1285,7 +1283,7 @@ void ggml_vec_dot_iq4_nl_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const v
const int8x16_t v_yh = vec_xl(QK8_0/2, y0->qs);
const int32x4_t v_xy = ggml_vec_dot(ggml_vec_dot(vec_splats(0), v_xl, v_yl), v_xh, v_yh);
sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * (v_xy[0] + v_xy[1] + v_xy[2] + v_xy[3]);
sumf += GGML_CPU_FP16_TO_FP32(x0->d) * GGML_CPU_FP16_TO_FP32(y0->d) * vec_hsum_i32x4(v_xy);
}
*s = sumf;
@ -1354,8 +1352,8 @@ void ggml_vec_dot_iq4_xs_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const v
h >>= 4;
sumi1 += (vsumi0[0] + vsumi0[1] + vsumi0[2] + vsumi0[3]) * ls1;
sumi2 += (vsumi1[0] + vsumi1[1] + vsumi1[2] + vsumi1[3]) * ls2;
sumi1 += vec_hsum_i32x4(vsumi0) * ls1;
sumi2 += vec_hsum_i32x4(vsumi1) * ls2;
}
sumf += GGML_CPU_FP16_TO_FP32(x[ibl].d) * y[ibl].d * (sumi1 + sumi2);

View File

@ -68,12 +68,6 @@ struct ggml_compute_params {
#endif // __VXE2__
#endif // __s390x__ && __VEC__
#if defined(__s390x__) && defined(GGML_NNPA)
#ifndef __NNPA__
#define __NNPA__
#endif // __NNPA__
#endif // __s390x__ && GGML_NNPA
#if defined(__ARM_FEATURE_SVE)
#include <sys/prctl.h>
#endif
@ -489,11 +483,16 @@ inline static int16x8_t vec_padd_s16(int16x8_t a, int16x8_t b) {
/**
* @see https://github.com/ggml-org/llama.cpp/pull/14037
*/
inline static float vec_hsum(float32x4_t v) {
inline static float vec_hsum_f32x4(float32x4_t v) {
float32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}
inline static int32_t vec_hsum_i32x4(int32x4_t v) {
int32x4_t v_temp = v + vec_reve(v);
return v_temp[0] + v_temp[1];
}
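// Worked example: for v = {a, b, c, d}, vec_reve(v) = {d, c, b, a}, so
// v + vec_reve(v) = {a+d, b+c, c+b, d+a}; adding the first two lanes yields
// the full horizontal sum a+b+c+d.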
inline static int32x4_t ggml_vec_dot(int32x4_t acc, int8x16_t a, int8x16_t b) {
const int16x8_t p = vec_mule(a, b) + vec_mulo(a, b);
return acc + (vec_unpackh(p) + vec_unpackl(p));

View File

@ -373,6 +373,9 @@ static const struct ggml_type_traits_cpu type_traits_cpu[GGML_TYPE_COUNT] = {
.vec_dot_type = GGML_TYPE_Q8_K,
.nrows = 1,
},
[GGML_TYPE_I32] = {
.from_float = (ggml_from_float_t) ggml_cpu_fp32_to_i32,
},
};
const struct ggml_type_traits_cpu * ggml_get_type_traits_cpu(enum ggml_type type) {
@ -1876,6 +1879,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_back_f32(params, tensor);
} break;
case GGML_OP_IM2COL_3D:
{
ggml_compute_forward_im2col_3d(params, tensor);
} break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
@ -2255,6 +2262,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
@ -2691,7 +2699,10 @@ struct ggml_cplan ggml_graph_plan(
if (ggml_is_quantized(node->type) ||
// F16 -> BF16 and BF16 -> F16 copies go through intermediate F32
(node->src[0]->type == GGML_TYPE_F16 && node->src[1] && node->src[1]->type == GGML_TYPE_BF16) ||
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16)) {
(node->src[0]->type == GGML_TYPE_BF16 && node->src[1] && node->src[1]->type == GGML_TYPE_F16) ||
// conversion between F32 and I32
(node->src[0]->type == GGML_TYPE_F32 && node->src[1] && node->src[1]->type == GGML_TYPE_I32) ||
(node->src[0]->type == GGML_TYPE_I32 && node->src[1] && node->src[1]->type == GGML_TYPE_F32)) {
cur = ggml_type_size(GGML_TYPE_F32) * node->ne[0] * n_tasks;
}
} break;
@ -3206,20 +3217,12 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec);
}
#elif defined(__NNPA__)
for (; i + 7 < n; i += 8) {
float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0));
float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4));
uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
}
for (; i + 3 < n; i += 4) {
float32x4_t v_x = vec_xl(0, (const float *)(x + i));
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
#elif defined(__riscv_zvfh)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl);
__riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
@ -3247,21 +3250,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec);
}
#elif defined(__NNPA__)
for (; i + 7 < n; i += 8) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i + 0));
vec_xst(v_yl, 0, (float *)(y + i + 4));
}
for (; i + 3 < n; i += 4) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i));
}
#endif
for (; i < n; ++i) {
@ -3276,6 +3264,13 @@ void ggml_cpu_fp32_to_bf16(const float * x, ggml_bf16_t * y, int64_t n) {
}
}
void ggml_cpu_fp32_to_i32(const float * x, int32_t * y, int64_t n) {
int64_t i = 0;
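// plain scalar conversion; the float -> int32_t assignment truncates toward zero (C semantics)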
for (; i < n; ++i) {
y[i] = x[i];
}
}
void ggml_cpu_bf16_to_fp32(const ggml_bf16_t * x, float * y, int64_t n) {
int64_t i = 0;
#if defined(__AVX2__)
@ -3465,14 +3460,6 @@ int ggml_cpu_has_vxe(void) {
#endif
}
int ggml_cpu_has_nnpa(void) {
#if defined(GGML_NNPA)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
return 1;

View File

@ -190,6 +190,7 @@ static const struct ggml_backend_i ggml_backend_cpu_i = {
/* .graph_compute = */ ggml_backend_cpu_graph_compute,
/* .event_record = */ NULL,
/* .event_wait = */ NULL,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cpu_guid(void) {
@ -348,8 +349,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t *
long pages = sysconf(_SC_PHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
*total = pages * page_size;
// "free" system memory is ill-defined, for practical purposes assume that all of it is free:
*free = *total;
#endif
#endif // _WIN32
GGML_UNUSED(dev);
}
@ -576,9 +579,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_vxe()) {
features.push_back({ "VXE", "1" });
}
if (ggml_cpu_has_nnpa()) {
features.push_back({ "NNPA", "1" });
}
if (ggml_cpu_has_wasm_simd()) {
features.push_back({ "WASM_SIMD", "1" });
}

View File

@ -154,7 +154,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
return compute_forward_q4_0(params, dst);
} else if (dst->src[0]->type == GGML_TYPE_F16) {
return compute_forward_kv_cache(params, dst);
return compute_forward_fp16(params, dst);
}
} else if (dst->op == GGML_OP_GET_ROWS) {
if (dst->src[0]->type == GGML_TYPE_Q4_0) {
@ -164,7 +164,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
return false;
}
bool compute_forward_kv_cache(ggml_compute_params * params, struct ggml_tensor * dst) {
bool compute_forward_fp16(ggml_compute_params * params, struct ggml_tensor * dst) {
static std::atomic_flag first_to_arrive = ATOMIC_FLAG_INIT;
const ggml_tensor * src0 = dst->src[0];
@ -534,13 +534,8 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type {
if (op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_cpu_kleidiai_buffer_type()) {
return (ggml::cpu::tensor_traits *) op->src[0]->extra;
}
else if (ggml_kleidiai_select_kernels(ctx.features, op) &&
op->src[0]->op == GGML_OP_VIEW &&
(op->src[1]->op == GGML_OP_PERMUTE || op->src[1]->op == GGML_OP_SOFT_MAX) &&
op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[0] != 2) ||
(op->src[1]->nb[0] != 4) ||
(op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
else if (ggml_kleidiai_select_kernels(ctx.features, op) && op->src[1]->ne[1] > 1) {
if ((op->src[0]->nb[1] * op->src[0]->ne[1] != op->src[0]->nb[2]) ||
(op->src[1]->nb[1] * op->src[1]->ne[1] != op->src[1]->nb[2])) {
return nullptr;
}

View File

@ -776,6 +776,24 @@ static void ggml_compute_forward_dup_f32(
id += ne00 * (ne01 - ir1);
}
}
} else if (dst->type == GGML_TYPE_I32) {
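// contiguous F32 -> I32 copy: each float is cast to int32_t (truncating toward zero)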
size_t id = 0;
int32_t * dst_ptr = (int32_t *) dst->data;
for (int i03 = 0; i03 < ne03; i03++) {
for (int i02 = 0; i02 < ne02; i02++) {
id += ne00 * ir0;
for (int i01 = ir0; i01 < ir1; i01++) {
for (int i00 = 0; i00 < ne00; i00++) {
const float * src0_ptr = (float *) ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
dst_ptr[id] = *src0_ptr;
id++;
}
}
id += ne00 * (ne01 - ir1);
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
@ -947,6 +965,144 @@ static void ggml_compute_forward_dup_f32(
}
}
}
} else if (dst->type == GGML_TYPE_I32) {
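// non-contiguous destination: walk dst with explicit i10..i13 counters and cast each float to int32_t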
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(int32_t *) dst_ptr = *(const float *) src0_ptr;
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
}
static void ggml_compute_forward_dup_i32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
GGML_ASSERT(ggml_nelements(dst) == ggml_nelements(src0));
GGML_TENSOR_UNARY_OP_LOCALS
const int ith = params->ith; // thread index
const int nth = params->nth; // number of threads
// parallelize by rows
const int nr = ne01;
// number of rows per thread
const int dr = (nr + nth - 1) / nth;
// row range for this thread
const int ir0 = dr * ith;
const int ir1 = MIN(ir0 + dr, nr);
// dst counters
int64_t i10 = 0;
int64_t i11 = 0;
int64_t i12 = 0;
int64_t i13 = 0;
// TODO: not optimal, but works
if (dst->type == GGML_TYPE_F32) {
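// I32 -> F32: each int32_t is converted to float at the computed dst position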
for (int64_t i03 = 0; i03 < ne03; i03++) {
for (int64_t i02 = 0; i02 < ne02; i02++) {
i10 += ne00 * ir0;
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
for (int64_t i01 = ir0; i01 < ir1; i01++) {
for (int64_t i00 = 0; i00 < ne00; i00++) {
const char * src0_ptr = ((char *) src0->data + i00*nb00 + i01*nb01 + i02*nb02 + i03*nb03);
char * dst_ptr = ((char *) dst->data + i10*nb0 + i11*nb1 + i12*nb2 + i13*nb3);
*(float *) dst_ptr = *(const int32_t *) src0_ptr;
if (++i10 == ne0) {
i10 = 0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
i10 += ne00 * (ne01 - ir1);
while (i10 >= ne0) {
i10 -= ne0;
if (++i11 == ne1) {
i11 = 0;
if (++i12 == ne2) {
i12 = 0;
if (++i13 == ne3) {
i13 = 0;
}
}
}
}
}
}
} else {
GGML_ABORT("fatal error"); // TODO: implement
}
@ -1177,6 +1333,10 @@ void ggml_compute_forward_dup(
{
ggml_compute_forward_dup_f32(params, dst);
} break;
case GGML_TYPE_I32:
{
ggml_compute_forward_dup_i32(params, dst);
} break;
default:
{
if (ggml_is_quantized(src0->type) && dst->type == GGML_TYPE_F32) {
@ -7027,6 +7187,209 @@ void ggml_compute_forward_im2col_back_f32(
}
}
// ggml_compute_forward_im2col_3d_f16
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
GGML_UNUSED(OC);
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t KH_KW = KH*KW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
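// each thread handles a strided subset of the input channels (iic = ith, ith + nth, ...)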
// micro kernel
ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
}
}
}
}
}
}
}
}
}
}
}
// ggml_compute_forward_im2col_3d_f32
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
GGML_UNUSED(OC);
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t KH_KW = KH*KW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
float * const wdata = (float *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
}
}
}
}
}
}
}
}
}
}
}
void ggml_compute_forward_im2col_3d(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_3d_f16(params, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_im2col_3d_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
void * a, void * b, float * c) {
const ggml_type_traits * traits = ggml_get_type_traits(type);
@ -8014,6 +8377,15 @@ static void ggml_compute_forward_pad_f32(
GGML_TENSOR_UNARY_OP_LOCALS
float * dst_ptr = (float *) dst->data;
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
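// lp*/rp* give the number of padded elements at the start/end of dims 0..3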
// TODO: optimize
@ -8022,10 +8394,12 @@ static void ggml_compute_forward_pad_f32(
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00);
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) {
if ((i0 >= lp0 && i0 < ne0 - rp0) \
&& (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
&& (i3 >= lp3 && i3 < ne3 - rp3)) {
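// inside the unpadded region: copy the value from src0 at the coordinates shifted by the left pads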
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;

View File

@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -114,26 +114,6 @@ extern "C" {
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
#elif defined(__NNPA__)
#define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x)
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) {
uint16x8_t v_h = vec_splats(h);
uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0);
return vec_extend_to_fp32_hi(v_hd, 0)[0];
}
static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) {
float32x4_t v_f = vec_splats(f);
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0);
uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0);
return vec_extract(v_h, 0);
}
#endif
// precomputed f32 table for f16 (256 KB)
@ -215,6 +195,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
// F16 SVE
#define DEFAULT_PG32 svptrue_b32()
#define DEFAULT_PG16 svptrue_b16()
#define GGML_F32Cxt svfloat16_t
#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
#define GGML_F16x_VEC GGML_F32Cxt
#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
{ \
sum1 = svadd_f16_x(pg16, sum1, sum2); \
sum3 = svadd_f16_x(pg16, sum3, sum4); \
sum1 = svadd_f16_x(pg16, sum1, sum3); \
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
(res) = (ggml_float) sum_f16; \
}
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
// F16 NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@ -1115,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_EPR GGML_F32_EPR
static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
#if defined(__NNPA__)
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x);
uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0);
return vec_extend_to_fp32_hi(v_xd, 0);
#else
float tmp[4];
for (int i = 0; i < 4; i++) {
@ -1129,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
// note: keep type-cast here to prevent compiler bugs
// see: https://github.com/ggml-org/llama.cpp/issues/12846
return vec_xl(0, (const float *)(tmp));
#endif
}
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#if defined(__NNPA__)
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0);
uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0);
x[0] = vec_extract(v_x, 0);
x[1] = vec_extract(v_x, 1);
x[2] = vec_extract(v_x, 2);
x[3] = vec_extract(v_x, 3);
#else
float arr[4];
// note: keep type-cast here to prevent compiler bugs
@ -1152,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
for (int i = 0; i < 4; i++) {
x[i] = GGML_CPU_FP32_TO_FP16(arr[i]);
}
#endif
}
#define GGML_F16_VEC GGML_F32x4

View File

@ -85,15 +85,21 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic)
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1);
for (int i = 0, avl; i < n; i += avl) {
avl = __riscv_vsetvl_e32m8(n - i);
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl);
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl);
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl);
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl);
int vl = __riscv_vsetvlmax_e32m8();
vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
vfloat32m8_t vsum;
vfloat32m8_t ax;
vfloat32m8_t ay;
vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
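// tail-undisturbed (_tu) ops leave inactive tail lanes of vsum untouched in the final short iteration,
// so a single reduction over the full vector after the loop is still correct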
for (int i = 0; i < n; i += vl) {
vl = __riscv_vsetvl_e32m8(n - i);
ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
}
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum);
vl = __riscv_vsetvlmax_e32m8();
vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -207,7 +213,94 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
#if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8; // vector length in bits
const int ggml_f16_epr = sve_register_length / 16; // number of fp16 elements per SVE register
const int ggml_f16_step = 8 * ggml_f16_epr; // process 8 SVE registers per iteration
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);
svfloat16_t sum4 = svdup_n_f16(0.0f);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
}
const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to a multiple of ggml_f16_epr
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
sum1 = svmad_f16_x(pg, hx, hy, sum1);
}
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
int vl = __riscv_vsetvlmax_e32m2();
vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
vfloat32m2_t vsum;
vfloat16m1_t ax;
vfloat16m1_t ay;
vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
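// vfwmacc widens: fp16 operands are multiplied and accumulated into f32 lanes to avoid fp16 overflow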
for (int i = 0; i < n; i += vl) {
vl = __riscv_vsetvl_e16m1(n - i);
ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
}
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
#endif // __riscv_zvfh
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
@ -231,14 +324,14 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
for (int i = np; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
// if you hit this, you are likely running outside the FP range
assert(!isnan(sumf) && !isinf(sumf));
#endif
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
#endif
#endif // GGML_SIMD
*s = sumf;
}
@ -257,6 +350,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
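// predicated SVE loop: svwhilelt keeps only the lanes with index < n active, so no scalar tail loop is needed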
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
@ -281,10 +380,24 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
}
#elif defined(__riscv_v_intrinsic)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
__riscv_vse32_v_f32m2(&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]) * g[i];
@ -328,6 +441,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
svdup_n_f32_x(pg, max)));
svst1_f32(pg, y + i, val);
sum += (ggml_float)svaddv_f32(pg, val);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),

View File

@ -119,14 +119,118 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16; // number of fp16 elements per SVE register
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
svfloat16_t sum_02 = svdup_n_f16(0.0f);
svfloat16_t sum_03 = svdup_n_f16(0.0f);
svfloat16_t sum_10 = svdup_n_f16(0.0f);
svfloat16_t sum_11 = svdup_n_f16(0.0f);
svfloat16_t sum_12 = svdup_n_f16(0.0f);
svfloat16_t sum_13 = svdup_n_f16(0.0f);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#else
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
@ -157,7 +261,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@ -293,13 +397,90 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 8 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
}
const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
GGML_F16x_VEC_STORE(y + k, ry, 0);
}
if (np2 < n) {
svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -321,7 +502,7 @@ inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y,
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -517,13 +698,39 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8;
const int ggml_f16_epr = sve_register_length / 16;
const int ggml_f16_step = 2 * ggml_f16_epr;
GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t ay1, ay2;
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_MUL(ay1, vx);
GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_MUL(ay2, vx);
GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
}
// leftovers
// the number of leftover elements will be less than ggml_f16_epr; apply a predicated multiply on the available elements only
if (np < n) {
svbool_t pg = svwhilelt_b16(np, n);
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
@ -543,7 +750,7 @@ inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
#endif
#else
// scalar
for (int i = 0; i < n; ++i) {
@ -795,7 +1002,39 @@ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/sr
}
#endif
#if defined(__ARM_NEON) && defined(__aarch64__)
#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
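// computes e^x for each active lane in single precision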
inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
const svfloat32_t n = svsub_f32_x(pg, z, r);
const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
const svbool_t c = svacgt_n_f32(pg, n, 126);
const svfloat32_t u = svmul_f32_x(pg, b, b);
const svfloat32_t j = svmla_f32_x(pg,
svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
@ -1030,6 +1269,14 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic
inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {

View File

@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND)
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmq*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
file(GLOB SRCS "template-instances/mmf*.cu")
list(APPEND GGML_SOURCES_CUDA ${SRCS})
if (GGML_CUDA_FA_ALL_QUANTS)
file(GLOB SRCS "template-instances/fattn-vec*.cu")

View File

@ -545,6 +545,31 @@ static __device__ __forceinline__ int ggml_cuda_dp4a(const int a, const int b, i
#endif // defined(GGML_USE_HIP)
}
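// multiply-add helpers: accumulate v*u into acc (a two-element dot product for the float2/half2 overloads)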
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float v, const float u) {
acc += v*u;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v, const float2 u) {
acc += v.x*u.x;
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && defined(GCN)
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(v*u);
acc += tmp.x + tmp.y;
#else
const float2 tmpv = __half22float2(v);
const float2 tmpu = __half22float2(u);
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && defined(GCN)
}
static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#if CUDART_VERSION >= 12080
const nv_bfloat16 e = __nv_cvt_e8m0_to_bf16raw(x);
@ -563,6 +588,40 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
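//
// illustrative example: for d = 3 this gives L = 2 and mp = 0x55555556,
// so 100/3 == (__umulhi(100, mp) + 100) >> 2 == 33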
static const uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
// compute L = ceil(log2(d));
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
// pack divisor as well to reduce error surface
return make_uint3(mp, L, d);
}
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
// fastdiv_values.z is unused and optimized away by the compiler.
// Compute high 32 bits of n * mp
const uint32_t hi = __umulhi(n, fastdiv_values.x);
// add n, apply bit shift
return (hi + n) >> fastdiv_values.y;
}
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
}
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);
static __device__ __forceinline__ float get_alibi_slope(

View File

@ -38,6 +38,8 @@ template<typename dst_t, typename src_t>
return __float2bfloat16(float(x));
} else if constexpr(std::is_same_v<src_t, nv_bfloat16>) {
return __bfloat162float(x);
} else if constexpr(std::is_same_v<dst_t, int32_t>) {
return int32_t(x);
} else {
return float(x);
}

View File

@ -374,6 +374,10 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
ggml_cpy_flt_cuda<nv_bfloat16, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<nv_bfloat16, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_I32) {
ggml_cpy_flt_cuda<float, int32_t> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_F32) {
ggml_cpy_flt_cuda<int32_t, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
} else {
GGML_ABORT("%s: unsupported type combination (%s to %s)\n", __func__,
ggml_type_name(src0->type), ggml_type_name(src1->type));

View File

@ -1,371 +0,0 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f16.cuh"
#define FATTN_KQ_STRIDE_TILE_F16 64
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_tile_ext_f16(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#if defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
NO_DEVICE_CODE;
return;
}
//In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slopef = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
const half slopeh = __float2half(slopef);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ half KQ[ncols*FATTN_KQ_STRIDE_TILE_F16];
half2 * KQ2 = (half2 *) KQ;
__shared__ half2 KV_tmp[FATTN_KQ_STRIDE_TILE_F16][D/2 + 1]; // Pad D to avoid memory bank conflicts.
half kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -HALF_MAX_HALF;
}
half2 kqsum[ncols/nwarps] = {{0.0f, 0.0f}};
half2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ half2 Q_h2[ncols][D/2];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i] : make_float2(0.0f, 0.0f);
Q_h2[j][i] = make_half2(scale, scale) * make_half2(tmp.x, tmp.y);
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F16; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F16) {
// Calculate KQ tile and keep track of new maximum KQ values:
half kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D/2; k_KQ_0 += WARP_SIZE) {
const int k_KQ = k_KQ_0 + threadIdx.x;
KV_tmp[i_KQ][k_KQ] = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ];
}
}
__syncthreads();
half2 sum2[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE][ncols/nwarps] = {{{0.0f, 0.0f}}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D/2; ++k_KQ) {
half2 K_k[FATTN_KQ_STRIDE_TILE_F16/WARP_SIZE];
half2 Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_h2[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE]*Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F16; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
half sum;
if (use_logit_softcap) {
const float2 tmp = __half22float2(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
sum = logit_softcap * tanhf(tmp.x + tmp.y);
} else {
sum = __low2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]) + __high2half(sum2[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum += mask ? slopeh*maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ] : __float2half(0.0f);
kqmax_new[j_KQ_0/nwarps] = ggml_cuda_hmax(kqmax_new[j_KQ_0/nwarps], sum);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F16 + i_KQ] = sum;
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]));
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F16/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const half2 diff = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] - __half2half2(kqmax[j0/nwarps]);
const half2 val = h2exp(diff);
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + val;
KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + i] = val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
KV_tmp[k][i] = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F16; k0 += 2) {
half2 V_k[(D/2)/WARP_SIZE][2];
half2 KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE][0] = KV_tmp[k0 + 0][i];
V_k[i0/WARP_SIZE][1] = KV_tmp[k0 + 1][i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ2[j*(FATTN_KQ_STRIDE_TILE_F16/2) + k0/2];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][0]* __low2half2(KQ_k[j0/nwarps]);
VKQ[j0/nwarps][i0/WARP_SIZE] += V_k[i0/WARP_SIZE][1]*__high2half2(KQ_k[j0/nwarps]);
}
}
}
__syncthreads();
}
//Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const half sink = __float2half(sinksf[head]);
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
half kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const half2 KQ_max_scale = __half2half2(hexp(kqmax[j0/nwarps] - kqmax_new_j));
kqmax[j0/nwarps] = kqmax_new_j;
const half val = hexp(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps].x = __hadd(__low2half(kqsum[j0/nwarps]), val);
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE] *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
half kqsum_j = __low2half(kqsum[j_VKQ_0/nwarps]) + __high2half(kqsum[j_VKQ_0/nwarps]);
kqsum_j = warp_reduce_sum((float)kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
half2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val /= __half2half2(kqsum_j);
}
dst2[j_dst_unrolled*(D/2) + i0] = __half22float2(dst_val);
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // defined(FLASH_ATTN_AVAILABLE) && defined(FP16_AVAILABLE)
}
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f16_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f16<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F16, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
const int32_t precision = KQV->op_params[3];
GGML_ASSERT(precision == GGML_PREC_DEFAULT);
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f16_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}

View File

@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -1,379 +0,0 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile-f32.cuh"
#define FATTN_KQ_STRIDE_TILE_F32 32
template<int D, int ncols, int nwarps, bool use_logit_softcap> // D == head size
#if !defined(GGML_USE_HIP)
__launch_bounds__(nwarps*WARP_SIZE, 2)
#endif // !defined(GGML_USE_HIP)
static __global__ void flash_attn_tile_ext_f32(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
static_assert(D % (2*WARP_SIZE) == 0, "D not divisible by 2*WARP_SIZE == 64.");
__shared__ float KQ[ncols*FATTN_KQ_STRIDE_TILE_F32];
__shared__ float KV_tmp[FATTN_KQ_STRIDE_TILE_F32][D + 1]; // Pad D to avoid memory bank conflicts.
float2 * KV_tmp2 = (float2 *) KV_tmp;
float kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
float2 VKQ[ncols/nwarps][(D/2)/WARP_SIZE] = {{{0.0f, 0.0f}}};
// Convert Q to half2 and store in registers:
__shared__ float Q_f[ncols][D];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D; i0 += 2*WARP_SIZE) {
float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0/2 + threadIdx.x] : make_float2(0.0f, 0.0f);
Q_f[j][i0 + 0*WARP_SIZE + threadIdx.x] = tmp.x * scale;
Q_f[j][i0 + 1*WARP_SIZE + threadIdx.x] = tmp.y * scale;
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*FATTN_KQ_STRIDE_TILE_F32; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*FATTN_KQ_STRIDE_TILE_F32) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += 2*WARP_SIZE) {
const half2 tmp = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + threadIdx.x];
KV_tmp[i_KQ][k_KQ_0 + 0*WARP_SIZE + threadIdx.x] = __low2float(tmp);
KV_tmp[i_KQ][k_KQ_0 + 1*WARP_SIZE + threadIdx.x] = __high2float(tmp);
}
}
__syncthreads();
float sum[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE][ncols/nwarps] = {{0.0f}};
#pragma unroll
for (int k_KQ = 0; k_KQ < D; ++k_KQ) {
float K_k[FATTN_KQ_STRIDE_TILE_F32/WARP_SIZE];
float Q_k[ncols/nwarps];
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
K_k[i_KQ_0/WARP_SIZE] = KV_tmp[i_KQ][k_KQ];
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
Q_k[j_KQ_0/nwarps] = Q_f[j_KQ][k_KQ];
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += K_k[i_KQ_0/WARP_SIZE] * Q_k[j_KQ_0/nwarps];
}
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < FATTN_KQ_STRIDE_TILE_F32; i_KQ_0 += WARP_SIZE) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
if (use_logit_softcap) {
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
}
sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps]);
KQ[j_KQ*FATTN_KQ_STRIDE_TILE_F32 + i_KQ] = sum[i_KQ_0/WARP_SIZE][j_KQ_0/nwarps];
}
}
__syncthreads();
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < FATTN_KQ_STRIDE_TILE_F32; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
KQ[j*FATTN_KQ_STRIDE_TILE_F32 + i] = val;
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
__syncthreads();
#pragma unroll
for (int k0 = 0; k0 < FATTN_KQ_STRIDE_TILE_F32; k0 += nwarps) {
const int k = k0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k)*stride_KV2 + i];
KV_tmp2[k*(D/2) + i].x = __low2float(tmp);
KV_tmp2[k*(D/2) + i].y = __high2float(tmp);
}
}
__syncthreads();
#pragma unroll
for (int k = 0; k < FATTN_KQ_STRIDE_TILE_F32; ++k) {
float2 V_k[(D/2)/WARP_SIZE];
float KQ_k[ncols/nwarps];
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
V_k[i0/WARP_SIZE] = KV_tmp2[k*(D/2) + i];
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
KQ_k[j0/nwarps] = KQ[j*FATTN_KQ_STRIDE_TILE_F32 + k];
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
VKQ[j0/nwarps][i0/WARP_SIZE].x += V_k[i0/WARP_SIZE].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/WARP_SIZE].y += V_k[i0/WARP_SIZE].y*KQ_k[j0/nwarps];
}
}
}
__syncthreads();
}
// Attention sink: adjust running max and sum once per head
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += WARP_SIZE) {
VKQ[j0/nwarps][i0/WARP_SIZE].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/WARP_SIZE].y *= KQ_max_scale;
}
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum(kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += WARP_SIZE) {
const int i0 = i00 + threadIdx.x;
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/WARP_SIZE];
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int cols_per_block, bool use_logit_softcap>
void launch_fattn_tile_f32_64_128(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
constexpr int D = 64;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
case 128: {
constexpr int D = 128;
constexpr int nwarps = 8;
constexpr size_t nbytes_shared = 0;
fattn_kernel_t fattn_kernel = flash_attn_tile_ext_f32<D, cols_per_block, nwarps, use_logit_softcap>;
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, FATTN_KQ_STRIDE_TILE_F32, true, true, false);
} break;
default: {
GGML_ABORT("FlashAttention without tensor cores only supports head sizes 64 and 128.");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
const ggml_tensor * Q = dst->src[0];
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (Q->ne[1] <= 16) {
constexpr int cols_per_block = 16;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
return;
}
constexpr int cols_per_block = 32;
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_f32_64_128<cols_per_block, use_logit_softcap>(ctx, dst);
}
}
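
For reference, both tile kernels implement the standard streaming ("online") softmax. Per Q row they carry a running maximum m, a running denominator S, and a V-weighted accumulator O, and every KQ tile of scores s_i updates them as

    m' = max(m, max_i s_i)
    S' = exp(m - m') * S + sum_i exp(s_i - m')
    O' = exp(m - m') * O + sum_i exp(s_i - m') * v_i

so that all exponentials stay bounded by 1. The final output is O/S; the division happens in the write-out loop when gridDim.y == 1, and otherwise through the (kqmax, kqsum) pairs written to dst_meta for a separate combine pass.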

View File

@@ -1,3 +0,0 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile_f32(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -0,0 +1,591 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-tile.cuh"
#define FATTN_TILE_NTHREADS 256
static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) {
if (GGML_CUDA_CC_IS_AMD(cc)) {
switch (D) {
case 64:
return ncols <= 16 ? 32 : 64;
case 128:
return ncols <= 16 ? 64 : warp_size;
case 256:
return 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
if (fast_fp16_available(cc)) {
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
GGML_ABORT("fatal error");
return -1;
}
}
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
GGML_ABORT("fatal error");
return -1;
}
}
static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return ncols <= 16 ? 32 : 64;
case 128:
return ncols <= 16 ? 64 : warp_size;
case 256:
return 64;
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#else
switch (D) {
case 64:
return ncols <= 16 ? 128 : 64;
case 128:
return ncols <= 16 ? 64 : 32;
case 256:
return 32;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
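// Note: fattn_tile_get_kq_stride_host above and this constexpr device variant
// must return the same value for every (D, ncols, warp_size) combination. The
// host result sizes the launch configuration while the device result sizes the
// shared-memory tiles, so a mismatch between the runtime cc checks and the
// preprocessor branches would corrupt the kernel's indexing.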
static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) {
#ifdef GGML_USE_HIP
switch (D) {
case 64:
return 64;
case 128:
return ncols <= 16 ? 2*warp_size : 128;
case 256:
return ncols <= 16 ? 128 : 2*warp_size;
default:
return -1;
}
#else
#ifdef FAST_FP16_AVAILABLE
switch (D) {
case 64:
return 64;
case 128:
return ncols <= 16 ? 128 : 64;
case 256:
return ncols <= 16 ? 64 : 128;
default:
return -1;
}
#else
switch (D) {
case 64:
return 64;
case 128:
return 128;
case 256:
return ncols <= 16 ? 128 : 64;
default:
return -1;
}
#endif // FAST_FP16_AVAILABLE
#endif // GGML_USE_HIP
GGML_UNUSED_VARS(ncols, warp_size);
}
template<int D, int ncols, bool use_logit_softcap> // D == head size
#ifdef GGML_USE_HIP
__launch_bounds__(FATTN_TILE_NTHREADS, 1)
#else
__launch_bounds__(FATTN_TILE_NTHREADS, 2)
#endif // GGML_USE_HIP
static __global__ void flash_attn_tile(
const char * __restrict__ Q,
const char * __restrict__ K,
const char * __restrict__ V,
const char * __restrict__ mask,
const char * __restrict__ sinks,
const int * __restrict__ KV_max,
float * __restrict__ dst,
float2 * __restrict__ dst_meta,
const float scale,
const float max_bias,
const float m0,
const float m1,
const uint32_t n_head_log2,
const float logit_softcap,
const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03,
const int32_t nb01, const int32_t nb02, const int32_t nb03,
const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13,
const int32_t nb11, const int32_t nb12, const int64_t nb13,
const int32_t nb21, const int32_t nb22, const int64_t nb23,
const int32_t ne31, const int32_t ne32, const int32_t ne33,
const int32_t nb31, const int32_t nb32, const int64_t nb33) {
#ifdef FLASH_ATTN_AVAILABLE
// Skip unused kernel variants for faster compilation:
#ifdef FP16_MMA_AVAILABLE
NO_DEVICE_CODE;
return;
#endif // FP16_MMA_AVAILABLE
if (use_logit_softcap && !(D == 128 || D == 256)) {
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
return;
}
constexpr int warp_size = 32;
constexpr int nwarps = FATTN_TILE_NTHREADS / warp_size;
constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size);
static_assert(kq_stride % warp_size == 0, "kq_stride not divisible by warp_size.");
constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size);
static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch");
// In this kernel Q, K, V are matrices while i, j, k are matrix indices.
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
const int sequence = blockIdx.z / ne02;
const int head = blockIdx.z - sequence*ne02;
const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix.
const float2 * Q_f2 = (const float2 *) (Q + nb03* sequence + nb02* head + nb01*ic0);
const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio));
const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape
const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0);
const float * sinksf = (const float *) (sinks);
const int stride_KV2 = nb11 / sizeof(half2);
const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1);
__shared__ float KQ[ncols][kq_stride];
#ifdef FAST_FP16_AVAILABLE
__shared__ half2 Q_tmp[ncols][D/2];
__shared__ half2 KV_tmp_h2[kq_stride * (kq_nbatch/2 + 1)]; // Padded to avoid memory bank conflicts.
half2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#else
__shared__ float Q_tmp[ncols][D];
__shared__ float KV_tmp_f[kq_stride * (kq_nbatch + 1)]; // Padded to avoid memory bank conflicts.
float2 * KV_tmp_f2 = (float2 *) KV_tmp_f;
float2 VKQ[ncols/nwarps][D/(2*warp_size)] = {{{0.0f, 0.0f}}};
#endif // FAST_FP16_AVAILABLE
float kqmax[ncols/nwarps];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
kqmax[j0/nwarps] = -FLT_MAX/2.0f;
}
float kqsum[ncols/nwarps] = {0.0f};
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const float2 tmp = ic0 + j < ne01 ? Q_f2[j*(nb01/sizeof(float2)) + i0 + threadIdx.x] : make_float2(0.0f, 0.0f);
#ifdef FAST_FP16_AVAILABLE
Q_tmp[j][i0 + threadIdx.x] = make_half2(tmp.x * scale, tmp.y * scale);
#else
Q_tmp[j][2*i0 + threadIdx.x] = tmp.x * scale;
Q_tmp[j][2*i0 + warp_size + threadIdx.x] = tmp.y * scale;
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) {
// Calculate KQ tile and keep track of new maximum KQ values:
float kqmax_new[ncols/nwarps];
#pragma unroll
for (int j = 0; j < ncols/nwarps; ++j) {
kqmax_new[j] = kqmax[j];
}
float sum[kq_stride/warp_size][ncols/nwarps] = {{0.0f}};
#pragma unroll
for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) {
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) {
const int i_KQ = i_KQ_0 + threadIdx.y;
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size) {
const half2 tmp_h2 = K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1 + threadIdx.x] = tmp_h2;
#else
const float2 tmp_f2 = __half22float2(tmp_h2);
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + threadIdx.x] = tmp_f2.x;
KV_tmp_f[i_KQ*(kq_nbatch + 1) + 2*k_KQ_1 + warp_size + threadIdx.x] = tmp_f2.y;
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
#ifdef FAST_FP16_AVAILABLE
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; ++k_KQ_1) {
half2 K_k[kq_stride/warp_size];
half2 Q_k[ncols/nwarps];
#else
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; ++k_KQ_1) {
float K_k[kq_stride/warp_size];
float Q_k[ncols/nwarps];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
K_k[i_KQ_0/warp_size] = KV_tmp_h2[i_KQ*(kq_nbatch/2 + 1) + k_KQ_1];
#else
K_k[i_KQ_0/warp_size] = KV_tmp_f [i_KQ*(kq_nbatch + 1) + k_KQ_1];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1];
#else
Q_k[j_KQ_0/nwarps] = Q_tmp[j_KQ][k_KQ_0 + k_KQ_1];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
}
}
}
if (k_KQ_0 + kq_nbatch < D) {
__syncthreads(); // Sync not needed on last iteration.
}
}
#pragma unroll
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
const int i_KQ = i_KQ_0 + threadIdx.x;
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
const int j_KQ = j_KQ_0 + threadIdx.y;
if (use_logit_softcap) {
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] = logit_softcap * tanhf(sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
}
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f;
kqmax_new[j_KQ_0/nwarps] = fmaxf(kqmax_new[j_KQ_0/nwarps], sum[i_KQ_0/warp_size][j_KQ_0/nwarps]);
KQ[j_KQ][i_KQ] = sum[i_KQ_0/warp_size][j_KQ_0/nwarps];
}
}
__syncthreads();
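// Online-softmax step: now that the new per-row maximum is known, rescale the
// running denominator (kqsum) and the V*softmax(KQ) accumulator (VKQ) by
// exp(m_old - m_new) so previously accumulated terms remain consistent.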
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
kqmax_new[j0/nwarps] = warp_reduce_max<warp_size>(kqmax_new[j0/nwarps]);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new[j0/nwarps]);
kqmax[j0/nwarps] = kqmax_new[j0/nwarps];
float kqsum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < kq_stride; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const float diff = KQ[j][i] - kqmax[j0/nwarps];
const float val = expf(diff);
kqsum_add += val;
KQ[j][i] = val;
}
kqsum[j0/nwarps] = kqsum[j0/nwarps]*KQ_max_scale + kqsum_add;
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D;
static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter");
#pragma unroll
for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) {
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) {
const int k_tile = k1 + threadIdx.y;
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
const half2 tmp = V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i];
#ifdef FAST_FP16_AVAILABLE
KV_tmp_h2[k_tile*(D/2) + i] = tmp;
#else
KV_tmp_f2[k_tile*(D/2) + i] = __half22float2(tmp);
#endif // FAST_FP16_AVAILABLE
}
}
__syncthreads();
#pragma unroll
for (int k1 = 0; k1 < V_cols_per_iter; ++k1) {
#ifdef FAST_FP16_AVAILABLE
half2 V_k[(D/2)/warp_size];
half2 KQ_k[ncols/nwarps];
#else
float2 V_k[(D/2)/warp_size];
float KQ_k[ncols/nwarps];
#endif // FAST_FP16_AVAILABLE
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
const int i = i0 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
V_k[i0/warp_size] = KV_tmp_h2[k1*(D/2) + i];
#else
V_k[i0/warp_size] = KV_tmp_f2[k1*(D/2) + i];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
const int j = j0 + threadIdx.y;
#ifdef FAST_FP16_AVAILABLE
const float tmp = KQ[j][k0 + k1];
KQ_k[j0/nwarps] = make_half2(tmp, tmp);
#else
KQ_k[j0/nwarps] = KQ[j][k0 + k1];
#endif // FAST_FP16_AVAILABLE
}
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
#ifdef FAST_FP16_AVAILABLE
VKQ[j0/nwarps][i0/warp_size] += V_k[i0/warp_size] *KQ_k[j0/nwarps];
#else
VKQ[j0/nwarps][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0/nwarps];
VKQ[j0/nwarps][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0/nwarps];
#endif // FAST_FP16_AVAILABLE
}
}
}
__syncthreads();
}
}
// Attention sink: adjust running max and sum once per head
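// The sink acts as one extra logit with no V row: it adds exp(sink - m) to the
// denominator (from a single lane, so it is not counted warp_size times) and
// may raise the running maximum, in which case kqsum and VKQ are rescaled
// exactly as in the online-softmax update above.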
if (sinksf && blockIdx.y == 0) {
const float sink = sinksf[head];
#pragma unroll
for (int j0 = 0; j0 < ncols; j0 += nwarps) {
float kqmax_new_j = fmaxf(kqmax[j0/nwarps], sink);
kqmax_new_j = warp_reduce_max<warp_size>(kqmax_new_j);
const float KQ_max_scale = expf(kqmax[j0/nwarps] - kqmax_new_j);
kqmax[j0/nwarps] = kqmax_new_j;
const float val = expf(sink - kqmax[j0/nwarps]);
kqsum[j0/nwarps] = kqsum[j0/nwarps] * KQ_max_scale;
if (threadIdx.x == 0) {
kqsum[j0/nwarps] += val;
}
#ifdef FAST_FP16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size] *= KQ_max_scale_h2;
}
#else
#pragma unroll
for (int i0 = 0; i0 < D/2; i0 += warp_size) {
VKQ[j0/nwarps][i0/warp_size].x *= KQ_max_scale;
VKQ[j0/nwarps][i0/warp_size].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
}
}
float2 * dst2 = (float2 *) dst;
#pragma unroll
for (int j_VKQ_0 = 0; j_VKQ_0 < ncols; j_VKQ_0 += nwarps) {
const int j_VKQ = j_VKQ_0 + threadIdx.y;
if (ic0 + j_VKQ >= ne01) {
return;
}
float kqsum_j = kqsum[j_VKQ_0/nwarps];
kqsum_j = warp_reduce_sum<warp_size>(kqsum_j);
const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y;
#pragma unroll
for (int i00 = 0; i00 < D/2; i00 += warp_size) {
const int i0 = i00 + threadIdx.x;
#ifdef FAST_FP16_AVAILABLE
float2 dst_val = __half22float2(VKQ[j_VKQ_0/nwarps][i0/warp_size]);
#else
float2 dst_val = VKQ[j_VKQ_0/nwarps][i0/warp_size];
#endif // FAST_FP16_AVAILABLE
if (gridDim.y == 1) {
dst_val.x /= kqsum_j;
dst_val.y /= kqsum_j;
}
dst2[j_dst_unrolled*(D/2) + i0] = dst_val;
}
if (gridDim.y != 1 && threadIdx.x == 0) {
dst_meta[j_dst_unrolled] = make_float2(kqmax[j_VKQ_0/nwarps], kqsum_j);
}
}
#else
GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale,
max_bias, m0, m1, n_head_log2, logit_softcap,
ne00, ne01, ne02, ne03,
nb01, nb02, nb03,
ne10, ne11, ne12, ne13,
nb11, nb12, nb13,
nb21, nb22, nb23,
ne31, ne32, ne33,
nb31, nb32, nb33);
NO_DEVICE_CODE;
#endif // FLASH_ATTN_AVAILABLE
}
template <int D, bool use_logit_softcap>
static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
const int id = ggml_cuda_get_device();
const int cc = ggml_cuda_info().devices[id].cc;
const int warp_size = 32;
const int nwarps = FATTN_TILE_NTHREADS / warp_size;
constexpr size_t nbytes_shared = 0;
if (Q->ne[1] > 16) {
constexpr int cols_per_block = 32;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
return;
}
constexpr int cols_per_block = 16;
fattn_kernel_t fattn_kernel = flash_attn_tile<D, cols_per_block, use_logit_softcap>;
const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size);
launch_fattn<D, cols_per_block, 1>
(ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size);
}
template <bool use_logit_softcap>
static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * Q = dst->src[0];
switch (Q->ne[0]) {
case 64: {
launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst);
} break;
case 128: {
launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst);
} break;
case 256: {
launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst);
} break;
default: {
GGML_ABORT("Unsupported head size");
} break;
}
}
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * KQV = dst;
float logit_softcap;
memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float));
if (logit_softcap == 0.0f) {
constexpr bool use_logit_softcap = false;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
} else {
constexpr bool use_logit_softcap = true;
launch_fattn_tile_switch_head_size<use_logit_softcap>(ctx, dst);
}
}
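
As a worked example of the stride tables above: on an NVIDIA GPU with FAST_FP16_AVAILABLE, D = 128 and ncols = 32 give kq_stride = 128 and kq_nbatch = 64, so V_cols_per_iter = kq_stride*kq_nbatch/D = 128*64/128 = 64, and all three static_asserts hold (128 % 32 == 0 for the warp size, 64 % 64 == 0 for kq_nbatch, 128 % 64 == 0 for V_cols_per_iter).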

View File

@@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,8 +1,7 @@
#include "common.cuh"
#include "fattn-common.cuh"
#include "fattn-mma-f16.cuh"
#include "fattn-tile-f16.cuh"
#include "fattn-tile-f32.cuh"
#include "fattn-tile.cuh"
#include "fattn-vec-f16.cuh"
#include "fattn-vec-f32.cuh"
#include "fattn-wmma-f16.cuh"
@@ -271,8 +270,7 @@ static void ggml_cuda_flash_attn_ext_vec_f32(ggml_backend_cuda_context & ctx, gg
// Best FlashAttention kernel for a specific GPU:
enum best_fattn_kernel {
BEST_FATTN_KERNEL_NONE = 0,
BEST_FATTN_KERNEL_TILE_F32 = 200,
BEST_FATTN_KERNEL_TILE_F16 = 210,
BEST_FATTN_KERNEL_TILE = 200,
BEST_FATTN_KERNEL_VEC_F32 = 100,
BEST_FATTN_KERNEL_VEC_F16 = 110,
BEST_FATTN_KERNEL_WMMA_F16 = 300,
@@ -411,10 +409,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const
}
// If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes:
if (prec == GGML_PREC_DEFAULT && fast_fp16_available(cc)) {
return BEST_FATTN_KERNEL_TILE_F16;
}
return BEST_FATTN_KERNEL_TILE_F32;
return BEST_FATTN_KERNEL_TILE;
}
void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@@ -422,11 +417,8 @@ void ggml_cuda_flash_attn_ext(ggml_backend_cuda_context & ctx, ggml_tensor * dst
switch (ggml_cuda_get_best_fattn_kernel(ggml_cuda_get_device(), dst)) {
case BEST_FATTN_KERNEL_NONE:
GGML_ABORT("fatal error");
case BEST_FATTN_KERNEL_TILE_F32:
ggml_cuda_flash_attn_ext_tile_f32(ctx, dst);
break;
case BEST_FATTN_KERNEL_TILE_F16:
ggml_cuda_flash_attn_ext_tile_f16(ctx, dst);
case BEST_FATTN_KERNEL_TILE:
ggml_cuda_flash_attn_ext_tile(ctx, dst);
break;
case BEST_FATTN_KERNEL_VEC_F32:
ggml_cuda_flash_attn_ext_vec_f32(ctx, dst);

View File

@@ -6,20 +6,17 @@ template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
if (i00 >= ne00) {
return;
}
const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
@@ -37,22 +34,25 @@ static __global__ void k_get_rows(
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}
}
}
template<typename src0_t, typename dst_t>
static __global__ void k_get_rows_float(
const src0_t * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
const int64_t ne00, /*const int64_t ne01, const int64_t ne02, const int64_t ne03,*/
/*const int64_t ne10, const int64_t ne11,*/ const int64_t ne12, /*const int64_t ne13,*/
/*const int64_t ne10,*/ const int64_t ne11, const int64_t ne12, /*const int64_t ne13,*/
/*const size_t s0,*/ const size_t s1, const size_t s2, const size_t s3,
/*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
for (int64_t z = blockIdx.z; z < ne11*ne12; z += gridDim.z) {
for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i00 = blockIdx.y * blockDim.x + threadIdx.x;
const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12;
const int i11 = z / ne12; // TODO fastdiv
const int i12 = z % ne12;
if (i00 >= ne00) {
return;
@@ -64,6 +64,8 @@ static __global__ void k_get_rows_float(
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
}
}
}
template<typename grad_t, typename dst_t>
@@ -98,7 +100,7 @@ static void get_rows_cuda_q(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -116,7 +118,7 @@ static void get_rows_cuda_q(
k_get_rows<qk, qr, dq><<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
@@ -131,7 +133,7 @@ static void get_rows_cuda_float(
cudaStream_t stream) {
const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, block_num_y, ne11*ne12);
const dim3 block_nums(ne10, MIN(block_num_y, UINT16_MAX), MIN(ne11*ne12, UINT16_MAX));
// strides in elements
// const size_t s0 = nb0 / sizeof(dst_t);
@@ -147,7 +149,7 @@ static void get_rows_cuda_float(
k_get_rows_float<<<block_nums, block_dims, 0, stream>>>(
src0_d, src1_d, dst_d,
ne00, /*ne01, ne02, ne03,*/
/*ne10, ne11,*/ ne12, /*ne13,*/
/*ne10,*/ ne11, ne12, /*ne13,*/
/* s0,*/ s1, s2, s3,
/* nb00,*/ nb01, nb02, nb03,
s10, s11, s12/*, s13*/);
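
The k_get_rows changes above replace the one-block-per-element mapping with grid-stride loops over the y and z grid dimensions, which is what makes the MIN(..., UINT16_MAX) clamps safe: CUDA limits gridDim.y and gridDim.z to 65535, and with a grid-stride loop the grid size only affects performance, not coverage. A generic, self-contained sketch of the pattern (not code from this diff):

// Grid-stride loop: every block walks the whole index range in steps of the
// total grid width, so any grid size >= 1 still processes all n elements.
static __global__ void scale_kernel(float * data, const int64_t n, const float factor) {
    for (int64_t i = blockIdx.x*(int64_t)blockDim.x + threadIdx.x; i < n; i += (int64_t)gridDim.x*blockDim.x) {
        data[i] *= factor;
    }
}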

View File

@@ -2109,6 +2109,11 @@ static void ggml_cuda_mul_mat_id(ggml_backend_cuda_context & ctx, ggml_tensor *
ggml_cuda_mul_mat_q(ctx, src0, src1, ids, dst);
return;
}
if (ggml_cuda_should_use_mmf(src0->type, cc, WARP_SIZE, src0->ne, src1->ne[2])) {
ggml_cuda_mul_mat_f(ctx, src0, src1, ids, dst);
return;
}
}
cudaStream_t stream = ctx.stream();
@@ -2452,6 +2457,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_IM2COL:
ggml_cuda_op_im2col(ctx, dst);
break;
case GGML_OP_IM2COL_3D:
ggml_cuda_op_im2col_3d(ctx, dst);
break;
case GGML_OP_CONV_2D:
ggml_cuda_op_conv2d(ctx, dst);
break;
@@ -3132,6 +3140,7 @@ static const ggml_backend_i ggml_backend_cuda_interface = {
/* .graph_compute = */ ggml_backend_cuda_graph_compute,
/* .event_record = */ ggml_backend_cuda_event_record,
/* .event_wait = */ ggml_backend_cuda_event_wait,
/* .optimize_graph = */ NULL,
};
static ggml_guid_t ggml_backend_cuda_guid() {
@@ -3458,6 +3467,12 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_IQ4_NL) {
return true;
}
if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) {
return true;
}
if (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) {
return true;
}
if (src0_type == src1_type && ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op->src[1])) {
return true;
}
@@ -3559,6 +3574,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
}
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_2D_DW:
case GGML_OP_CONV_TRANSPOSE_2D:
@@ -3570,9 +3586,9 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
case GGML_OP_GROUP_NORM:
case GGML_OP_PAD:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_UPSCALE:
case GGML_OP_PAD:
case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_ARANGE:
case GGML_OP_TIMESTEP_EMBEDDING:

View File

@@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= IC_KD_KH_KW) {
return;
}
const int64_t iic = i / KD_KH_KW;
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz += MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
dst[offset_dst] = src[offset_src];
}
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_OH = N*OD*OH;
const int64_t OD_OH = OD*OH;
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE), 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
static void im2col_3d_cuda_f32(const float * src, float * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
if (dst->type == GGML_TYPE_F16) {
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
} else {
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
}
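
Spelled out, the flattened offsets computed by im2col_3d_kernel are the row-major index expressions

    offset_src = (((in*IC + iic)*ID + iid)*IH + iih)*IW + iiw
    offset_dst = (((in*OD + iod)*OH + ioh)*OW + iow) * IC*KD*KH*KW
               + ((iic*KD + ikd)*KH + ikh)*KW + ikw

with iiw = iow*s0 + ikw*d0 - p0, iih = ioh*s1 + ikh*d1 - p1 and iid = iod*s2 + ikd*d2 - p2. Source coordinates that fall outside [0, IW), [0, IH) or [0, ID) are written to dst as zeros, which implements the padding.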

View File

@@ -3,3 +3,4 @@
#define CUDA_IM2COL_BLOCK_SIZE 256
void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@@ -1,3 +1,4 @@
#pragma once
// This file contains primitives that expose the tensor core PTX instructions for CUDA code.
// The primitives can be used in a similar way as the nvcuda::wmma interface but with a well-defined memory layout.
// The documentation for the PTX instructions can be found under:

View File

@@ -1,343 +1,12 @@
#include "ggml.h"
#include "common.cuh"
#include "mma.cuh"
#include "mmf.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
template <typename T, int rows_per_block, int cols_per_block, int nwarps>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_y, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int channel_dst = blockIdx.y;
const int channel_x = channel_dst / channel_ratio;
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + channel_y *stride_channel_y;
dst += int64_t(sample_dst)*stride_sample_dst + channel_dst*stride_channel_dst;
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
tile_C C[ntA][ntB];
T * tile_xy = (T *) data_mmv + threadIdx.y*(tile_A::I * tile_k_padded);
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) data_mmv;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
}
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template <typename T, int cols_per_block>
static void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
GGML_ASSERT(!ids && "mul_mat_id not implemented");
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const dim3 block_nums(nrows_x/rows_per_block, nchannels_dst, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f<T, rows_per_block, cols_per_block, 1><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
mul_mat_f<T, rows_per_block, cols_per_block, 2><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
mul_mat_f<T, rows_per_block, cols_per_block, 3><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
mul_mat_f<T, rows_per_block, cols_per_block, 4><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
mul_mat_f<T, rows_per_block, cols_per_block, 5><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
mul_mat_f<T, rows_per_block, cols_per_block, 6><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
mul_mat_f<T, rows_per_block, cols_per_block, 7><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
mul_mat_f<T, rows_per_block, cols_per_block, 8><<<block_nums, block_dims, nbytes_shared, stream>>>
(x, y, ids, dst, ncols_x, nchannels_y, stride_row, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst) {
GGML_ASSERT( src1->type == GGML_TYPE_F32);
GGML_ASSERT(!ids || ids->type == GGML_TYPE_I32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const size_t ts_src0 = ggml_type_size(src0->type);
@@ -365,38 +34,49 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
const int64_t s13 = src1->nb[3] / ts_src1;
const int64_t s3 = dst->nb[3] / ts_dst;
const int64_t ids_s0 = ids ? ids->nb[0] / ggml_type_size(ids->type) : 0;
const int64_t ids_s1 = ids ? ids->nb[1] / ggml_type_size(ids->type) : 0;
// For MUL_MAT_ID the memory layout is different than for MUL_MAT:
const int64_t ncols_dst = ids ? ne2 : ne1;
const int64_t nchannels_y = ids ? ne11 : ne12;
const int64_t nchannels_dst = ids ? ne1 : ne2;
const int64_t stride_channel_dst = ids ? s1 : s2;
const int64_t stride_channel_y = ids ? s11 : s12;
GGML_ASSERT(!ids || ncols_dst == 1);
const int64_t stride_col_dst = ids ? s2 : s1;
const int64_t stride_col_y = ids ? s12 : s11;
const int64_t stride_channel_dst = ids ? s1 : s2;
int64_t stride_channel_y = ids ? s11 : s12;
int64_t nchannels_y = ids ? ne11 : ne12;
// mul_mat_id: handle broadcast
if (ids && nchannels_y == 1) {
stride_channel_y = 0;
nchannels_y = ids->ne[0];
}
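// With stride_channel_y == 0 every expert id reads the same src1 slice, and
// raising nchannels_y to ids->ne[0] keeps the launch covering all experts.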
switch (src0->type) {
case GGML_TYPE_F32: {
const float * src0_d = (const float *) src0->data;
constexpr int vals_per_T = 1;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_F16: {
const half2 * src0_d = (const half2 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
case GGML_TYPE_BF16: {
const nv_bfloat162 * src0_d = (const nv_bfloat162 *) src0->data;
constexpr int vals_per_T = 2;
mul_mat_f_switch_cols_per_block(
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, s11/vals_per_T, s1,
ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
src0_d, src1_d, ids_d, dst_d, ne00/vals_per_T, ne01, ncols_dst, s01/vals_per_T, stride_col_y/vals_per_T, stride_col_dst,
ids_s0, ids_s1, ne02, nchannels_y, nchannels_dst, s02/vals_per_T, stride_channel_y, stride_channel_dst,
ne03, ne3, s03/vals_per_T, s13, s3, ctx.stream());
} break;
default:
@ -404,16 +84,22 @@ void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * sr
}
}
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11) {
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols) {
if (ggml_is_quantized(type)) {
return false;
}
if (src0_ne[0] % (warp_size * (4/ggml_type_size(type))) != 0) {
return false;
}
if (src0_ne[1] % MMF_ROWS_PER_BLOCK != 0) {
return false;
}
if (ne11 > 16) {
if (src1_ncols > 16) {
return false;
}
switch (type) {
case GGML_TYPE_F32:
return ampere_mma_available(cc);
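Put in concrete terms, the checks above amount to the following shape requirements (worked out for the usual warp_size = 32 and MMF_ROWS_PER_BLOCK = 32; illustrative only):

// F32        (ggml_type_size = 4): ne00 must be a multiple of 32 * (4/4) = 32.
// F16 / BF16 (ggml_type_size = 2): ne00 must be a multiple of 32 * (4/2) = 64.
// In every case ne01 must be a multiple of 32, and the number of dst columns
// per call (src1_ncols) must be at most 16 -- the largest instantiated
// ncols_dst case.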

View File

@ -1,5 +1,473 @@
#pragma once
#include "mma.cuh"
#include "common.cuh"
using namespace ggml_cuda_mma;
#define MMF_ROWS_PER_BLOCK 32
void ggml_cuda_mul_mat_f(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, int64_t ne11);
bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const int64_t * src0_ne, const int src1_ncols);
template <typename T, int rows_per_block, int cols_per_block, int nwarps, bool has_ids>
__launch_bounds__(ggml_cuda_get_physical_warp_size()*nwarps, 1)
static __global__ void mul_mat_f(
const T * __restrict__ x, const float * __restrict__ y, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols, const int nchannels_dst, const int stride_row, const int stride_col_y, const int stride_col_dst,
const int stride_col_id, const int stride_row_id,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
typedef tile<16, 8, float> tile_C;
constexpr int warp_size = ggml_cuda_get_physical_warp_size();
constexpr int tile_k_padded = warp_size + 4;
constexpr int ntA = rows_per_block / tile_A::I;
constexpr int ntB = (cols_per_block + tile_B::I - 1) / tile_B::I;
const int row0 = blockIdx.x * rows_per_block;
const int expert_idx = has_ids ? blockIdx.y : 0;
const int channel_dst = has_ids ? 0 : blockIdx.y;
const int channel_x = has_ids ? expert_idx : (channel_dst / channel_ratio);
const int channel_y = channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
x += int64_t(sample_x) *stride_sample_x + channel_x *stride_channel_x + row0*stride_row ;
y += int64_t(sample_y) *stride_sample_y + (has_ids ? 0 : channel_y *stride_channel_y);
dst += int64_t(sample_dst)*stride_sample_dst + (has_ids ? 0 : channel_dst*stride_channel_dst);
const float2 * y2 = (const float2 *) y;
extern __shared__ char data_mmv[];
char * shmem_base = data_mmv;
int * slot_map = (int *) shmem_base;
char * compute_base = has_ids ? (shmem_base + GGML_PAD(cols_per_block, 16) * sizeof(int)) : shmem_base;
tile_C C[ntA][ntB];
T * tile_xy = (T *) compute_base + threadIdx.y*(tile_A::I * tile_k_padded);
if constexpr (has_ids) {
__shared__ int has_any;
if (threadIdx.y == 0) {
int local_has_any = 0;
for (int j = threadIdx.x; j < cols_per_block; j += warp_size) {
int slot = -1;
for (int k = 0; k < nchannels_dst; ++k) {
const int idv = ids[j*stride_row_id + k*stride_col_id];
if (idv == expert_idx) {
slot = k;
break;
}
}
if (j < cols_per_block) {
local_has_any |= (slot >= 0);
slot_map[j] = slot;
}
}
has_any = warp_reduce_any(local_has_any);
}
__syncthreads();
if (has_any == 0) {
return;
}
}
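slot_map records, for each destination column handled by this block, which expert slot (if any) routed that column to the expert this block computes. A small worked example with made-up id values:

// Hypothetical: expert_idx = 2, cols_per_block = 4, nchannels_dst = 3 slots.
//   col 0: ids = {2, 5, 7} -> slot_map[0] =  0
//   col 1: ids = {0, 2, 4} -> slot_map[1] =  1
//   col 2: ids = {1, 3, 6} -> slot_map[2] = -1  (this expert not selected)
//   col 3: ids = {9, 9, 2} -> slot_map[3] =  2
// has_any is non-zero, so the block continues; columns mapped to -1 feed
// zeros into the mma tiles and are skipped when the results are written.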
for (int col = threadIdx.y*warp_size + threadIdx.x; col < ncols; col += nwarps*warp_size) {
tile_A A[ntA][warp_size / tile_A::J];
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int i = 0; i < tile_A::I; ++i) {
tile_xy[i*tile_k_padded + threadIdx.x] = x[(itA*tile_A::I + i)*stride_row + col];
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_A::J) {
load_ldmatrix(A[itA][k0/tile_A::J], tile_xy + k0, tile_k_padded);
}
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
if constexpr (std::is_same_v<T, float>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
tile_xy[j0*tile_k_padded + threadIdx.x] = j < cols_per_block ? y[j*stride_col_y + col] : 0.0f;
} else {
float val = 0.0f;
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
val = y[slot*stride_channel_y + j*stride_col_y + col];
}
}
tile_xy[j0*tile_k_padded + threadIdx.x] = val;
}
}
} else if constexpr (std::is_same_v<T, half2> || std::is_same_v<T, nv_bfloat162>) {
#pragma unroll
for (int j0 = 0; j0 < tile_B::I; ++j0) {
const int j = j0 + itB*tile_B::I;
if constexpr (!has_ids) {
const float2 tmp = j < cols_per_block ? y2[j*stride_col_y + col] : make_float2(0.0f, 0.0f);
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
} else {
float2 tmp = make_float2(0.0f, 0.0f);
if (j < cols_per_block) {
const int slot = slot_map[j];
if (slot >= 0) {
const float2 * y2_slot = (const float2 *)(y + slot*stride_channel_y);
tmp = y2_slot[j*stride_col_y + col];
}
}
tile_xy[j0*tile_k_padded + threadIdx.x] = {tmp.x, tmp.y};
}
}
} else {
static_assert(std::is_same_v<T, void>, "unsupported type");
}
#pragma unroll
for (int k0 = 0; k0 < warp_size; k0 += tile_B::J) {
tile_B B;
load_ldmatrix(B, tile_xy + k0, tile_k_padded);
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
mma(C[itA][itB], A[itA][k0/tile_B::J], B);
}
}
}
}
float * buf_iw = (float *) compute_base;
constexpr int kiw = nwarps*rows_per_block + 4;
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int itB = 0; itB < ntB; ++itB) {
#pragma unroll
for (int itA = 0; itA < ntA; ++itA) {
#pragma unroll
for (int l = 0; l < tile_C::ne; ++l) {
const int i = threadIdx.y*rows_per_block + itA*tile_C::I + tile_C::get_i(l);
const int j = itB*tile_C::J + tile_C::get_j(l);
buf_iw[j*kiw + i] = C[itA][itB].x[l];
}
}
}
if (nwarps > 1) {
__syncthreads();
}
#pragma unroll
for (int j0 = 0; j0 < cols_per_block; j0 += nwarps) {
const int j = j0 + threadIdx.y;
if (j0 + nwarps > cols_per_block && j >= cols_per_block) {
return;
}
float sum = 0.0f;
static_assert(rows_per_block == warp_size, "need loop/check");
#pragma unroll
for (int i0 = 0; i0 < nwarps*rows_per_block; i0 += rows_per_block) {
const int i = i0 + threadIdx.x;
sum += buf_iw[j*kiw + i];
}
if constexpr (!has_ids) {
dst[j*stride_col_dst + row0 + threadIdx.x] = sum;
} else {
const int slot = (j < cols_per_block) ? slot_map[j] : -1;
if (slot >= 0) {
dst[slot*stride_channel_dst + j*stride_col_dst + row0 + threadIdx.x] = sum;
}
}
}
#else
GGML_UNUSED_VARS(x, y, ids, dst,
ncols, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
NO_DEVICE_CODE;
#endif // !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
}
template<typename T, int cols_per_block, int nwarps>
static inline void mul_mat_f_switch_ids(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nchannels_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t channel_ratio, const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst,
const int64_t sample_ratio, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
const dim3 & block_nums, const dim3 & block_dims, const int nbytes_shared_total, cudaStream_t stream) {
if (ids) {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, true><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
} else {
mul_mat_f<T, MMF_ROWS_PER_BLOCK, cols_per_block, nwarps, false><<<block_nums, block_dims, nbytes_shared_total, stream>>>
(x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
}
}
template <typename T, int cols_per_block>
void mul_mat_f_cuda(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int64_t stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
typedef tile<16, 8, T> tile_A;
typedef tile< 8, 8, T> tile_B;
GGML_ASSERT(ncols_x % 2 == 0);
GGML_ASSERT(stride_row % 2 == 0);
GGML_ASSERT(stride_col_y % 2 == 0);
GGML_ASSERT(ids || nchannels_dst % nchannels_x == 0);
GGML_ASSERT( nsamples_dst % nsamples_x == 0);
const int64_t channel_ratio = nchannels_dst / nchannels_x;
const int64_t sample_ratio = nsamples_dst / nsamples_x;
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
int64_t nwarps_best = 1;
int64_t niter_best = (ncols_x + warp_size*2 - 1) / (warp_size*2);
int64_t max_block_size = 256;
for (int64_t nwarps = 2; nwarps <= max_block_size/warp_size; nwarps++) {
const int64_t niter = (ncols_x + nwarps*warp_size*2 - 1) / (nwarps*warp_size*2);
if (niter < niter_best) {
niter_best = niter;
nwarps_best = nwarps;
}
}
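The loop picks the warp count that minimizes the per-thread iteration count over ncols_x. For example, with an assumed ncols_x = 4096, warp_size = 32 and max_block_size = 256:

//   nwarps = 1: niter = ceil(4096 / (1*32*2)) = 64
//   nwarps = 2: niter = ceil(4096 / (2*32*2)) = 32
//   nwarps = 4: niter = ceil(4096 / (4*32*2)) = 16
//   nwarps = 8: niter = ceil(4096 / (8*32*2)) =  8   <- nwarps_best
// The search stops at max_block_size / warp_size = 8 warps (256 threads per block).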
constexpr int rows_per_block = MMF_ROWS_PER_BLOCK;
const int nbytes_shared_iter = nwarps_best * tile_A::I * (warp_size + 4) * 4;
const int nbytes_shared_combine = GGML_PAD(cols_per_block, tile_B::I) * (nwarps_best*rows_per_block + 4) * 4;
const int nbytes_shared = std::max(nbytes_shared_iter, nbytes_shared_combine);
const int nbytes_slotmap = ids ? GGML_PAD(cols_per_block, 16) * sizeof(int) : 0;
const int nbytes_shared_total = nbytes_shared + nbytes_slotmap;
const int64_t grid_y = ids ? nchannels_x : nchannels_dst; // per expert when ids present
const dim3 block_nums(nrows_x/rows_per_block, grid_y, nsamples_dst);
const dim3 block_dims(warp_size, nwarps_best, 1);
switch (nwarps_best) {
case 1: {
mul_mat_f_switch_ids<T, cols_per_block, 1>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 2: {
mul_mat_f_switch_ids<T, cols_per_block, 2>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 3: {
mul_mat_f_switch_ids<T, cols_per_block, 3>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 4: {
mul_mat_f_switch_ids<T, cols_per_block, 4>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 5: {
mul_mat_f_switch_ids<T, cols_per_block, 5>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 6: {
mul_mat_f_switch_ids<T, cols_per_block, 6>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 7: {
mul_mat_f_switch_ids<T, cols_per_block, 7>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
case 8: {
mul_mat_f_switch_ids<T, cols_per_block, 8>(
x, y, ids, dst, ncols_x, nchannels_dst, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst, block_nums, block_dims, nbytes_shared_total, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
GGML_UNUSED_VARS(nchannels_y);
}
template <typename T>
static void mul_mat_f_switch_cols_per_block(
const T * x, const float * y, const int32_t * ids, float * dst,
const int64_t ncols_x, const int64_t nrows_x, const int64_t ncols_dst,
const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst,
const int64_t stride_col_id, const int stride_row_id,
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst,
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst,
cudaStream_t stream) {
switch (ncols_dst) {
case 1: {
mul_mat_f_cuda<T, 1>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 2: {
mul_mat_f_cuda<T, 2>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 3: {
mul_mat_f_cuda<T, 3>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 4: {
mul_mat_f_cuda<T, 4>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 5: {
mul_mat_f_cuda<T, 5>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 6: {
mul_mat_f_cuda<T, 6>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 7: {
mul_mat_f_cuda<T, 7>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 8: {
mul_mat_f_cuda<T, 8>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 9: {
mul_mat_f_cuda<T, 9>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 10: {
mul_mat_f_cuda<T, 10>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 11: {
mul_mat_f_cuda<T, 11>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 12: {
mul_mat_f_cuda<T, 12>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 13: {
mul_mat_f_cuda<T, 13>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 14: {
mul_mat_f_cuda<T, 14>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 15: {
mul_mat_f_cuda<T, 15>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
case 16: {
mul_mat_f_cuda<T, 16>(x, y, ids, dst, ncols_x, nrows_x, stride_row, stride_col_y, stride_col_dst,
stride_col_id, stride_row_id, nchannels_x, nchannels_y, nchannels_dst, stride_channel_x, stride_channel_y, stride_channel_dst,
nsamples_x, nsamples_dst, stride_sample_x, stride_sample_y, stride_sample_dst, stream);
} break;
default: {
GGML_ABORT("fatal error");
} break;
}
}
#define DECL_MMF_CASE_HELPER(T, ncols_dst) \
template void mul_mat_f_cuda<T, ncols_dst>( \
const T * x, const float * y, const int32_t * ids, float * dst, \
const int64_t ncols_x, const int64_t nrows_x, const int64_t stride_row, const int64_t stride_col_y, const int64_t stride_col_dst, \
const int64_t stride_col_id, const int64_t stride_row_id, \
const int64_t nchannels_x, const int64_t nchannels_y, const int64_t nchannels_dst, \
const int64_t stride_channel_x, const int64_t stride_channel_y, const int64_t stride_channel_dst, const int64_t nsamples_x,\
const int64_t nsamples_dst, const int64_t stride_sample_x, const int64_t stride_sample_y, const int64_t stride_sample_dst, \
cudaStream_t stream);
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA)
#define DECL_MMF_CASE_EXTERN(ncols_dst) \
extern DECL_MMF_CASE_HELPER(float, ncols_dst) \
extern DECL_MMF_CASE_HELPER(half2, ncols_dst) \
extern DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
#define DECL_MMF_CASE(ncols_dst) \
DECL_MMF_CASE_HELPER(float, ncols_dst) \
DECL_MMF_CASE_HELPER(half2, ncols_dst) \
DECL_MMF_CASE_HELPER(nv_bfloat162, ncols_dst)
DECL_MMF_CASE_EXTERN(1);
DECL_MMF_CASE_EXTERN(2);
DECL_MMF_CASE_EXTERN(3);
DECL_MMF_CASE_EXTERN(4);
DECL_MMF_CASE_EXTERN(5);
DECL_MMF_CASE_EXTERN(6);
DECL_MMF_CASE_EXTERN(7);
DECL_MMF_CASE_EXTERN(8);
DECL_MMF_CASE_EXTERN(9);
DECL_MMF_CASE_EXTERN(10);
DECL_MMF_CASE_EXTERN(11);
DECL_MMF_CASE_EXTERN(12);
DECL_MMF_CASE_EXTERN(13);
DECL_MMF_CASE_EXTERN(14);
DECL_MMF_CASE_EXTERN(15);
DECL_MMF_CASE_EXTERN(16);
#else
#define DECL_MMF_CASE(ncols_dst)
#endif

View File

@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) {
const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
constexpr int qk = ggml_cuda_type_traits<type>::qk;
constexpr int qi = ggml_cuda_type_traits<type>::qi;
@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const int channel_dst = blockIdx.y;
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio;
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst;
const int sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio;
const int sample_y = sample_dst;
const uint32_t channel_dst = blockIdx.y;
const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
const uint32_t sample_dst = blockIdx.z;
const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const uint32_t sample_y = sample_dst;
// partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const int channel_ratio = nchannels_dst / nchannels_x;
const int sample_ratio = nsamples_dst / nsamples_x;
const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
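init_fastdiv_values, fastdiv and fastmodulo are defined in the shared CUDA header and are not part of this diff. Below is a minimal sketch of how such helpers are typically built (magic-multiplier division in the Granlund/Montgomery style), assuming the uint3 packs {magic multiplier, shift, original divisor}; the actual ggml helpers may differ in detail.

// sketch only, not the actual ggml implementation
#include <cstdint>

static uint3 init_fastdiv_values_sketch(uint32_t d) {
    uint32_t L = 0;
    while (L < 32 && (uint32_t{1} << L) < d) {
        ++L;                                    // L = ceil(log2(d))
    }
    const uint32_t mp = (uint32_t) (((uint64_t{1} << 32) * ((uint64_t{1} << L) - d)) / d + 1);
    return make_uint3(mp, L, d);                // {multiplier, shift, divisor}
}

static uint32_t fastdiv_sketch(uint32_t n, uint3 fd) {
    const uint32_t hi = (uint32_t) (((uint64_t) n * fd.x) >> 32);   // __umulhi(n, fd.x) on the device
    return (uint32_t) ((hi + (uint64_t) n) >> fd.y);                // 64-bit add keeps the carry
}

static uint32_t fastmodulo_sketch(uint32_t n, uint3 fd) {
    return n - fastdiv_sketch(n, fd) * fd.z;    // remainder via the divisor kept in .z
}

Replacing plain / and % on the hot path avoids hardware integer division, which is the point of passing these precomputed uint3 values into the kernels.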
const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size;
@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
GGML_ASSERT(!ids || ncols_dst == 1);
switch (ncols_dst) {
case 1:
{
case 1: {
constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 2:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 2: {
constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 3:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 3: {
constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 4:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 4: {
constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 5:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 5: {
constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 6:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 6: {
constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 7:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 7: {
constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
case 8:
{
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
case 8: {
constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst);
break;
}
(vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
} break;
default:
GGML_ABORT("fatal error");
break;

View File

@ -105,7 +105,8 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}
template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x, float * dst,
static __global__ void rms_norm_f32(const float * x,
float * dst,
const int ncols,
const int64_t stride_row,
const int64_t stride_channel,
@ -115,19 +116,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
const int64_t mul_stride_row = 0,
const int64_t mul_stride_channel = 0,
const int64_t mul_stride_sample = 0,
const int mul_ncols = 0,
const int mul_nrows = 0,
const int mul_nchannels = 0,
const int mul_nsamples = 0,
const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
const float * add = nullptr,
const int64_t add_stride_row = 0,
const int64_t add_stride_channel = 0,
const int64_t add_stride_sample = 0,
const int add_ncols = 0,
const int add_nrows = 0,
const int add_nchannels = 0,
const int add_nsamples = 0) {
const uint3 add_ncols_packed = make_uint3(0, 0, 0),
const uint3 add_nrows_packed = make_uint3(0, 0, 0),
const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
const int nrows = gridDim.x;
const int nchannels = gridDim.y;
@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) {
const int mul_row = row % mul_nrows;
const int mul_channel = channel % mul_nchannels;
const int mul_sample = sample % mul_nsamples;
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row;
const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
}
if constexpr (do_add) {
const int add_row = row % add_nrows;
const int add_channel = channel % add_nchannels;
const int add_sample = sample % add_nsamples;
const int add_row = fastmodulo(row, add_nrows_packed);
const int add_channel = fastmodulo(channel, add_nchannels_packed);
const int add_sample = fastmodulo(sample, add_nsamples_packed);
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
}
@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
// sum up partial sums
tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size");
static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE;
const int warp_id = tid / WARP_SIZE;
const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) {
s_sum[warp_id] = tmp;
}
__syncthreads();
tmp = 0.0f;
if (lane_id < (block_size / WARP_SIZE)) {
tmp = s_sum[lane_id];
}
tmp = warp_reduce_sum(tmp);
}
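With the relaxed static_assert this path now also covers the new 256-thread launches. For block_size = 256 and WARP_SIZE = 32 the two-stage reduction works out as follows:

// 256 threads -> 8 warps. Each warp first reduces its own 32 partial sums;
// lane 0 of warp w stores the result in s_sum[w], so only s_sum[0..7] are
// written. After __syncthreads(), lanes 0..7 of every warp reload
// s_sum[0..7] (all other lanes contribute 0.0f) and one more
// warp_reduce_sum leaves the block-wide sum in every thread.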
@ -182,11 +185,11 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply && do_add) {
const int mul_col = col % mul_ncols;
const int add_col = col % add_ncols;
const int mul_col = fastmodulo(col, mul_ncols_packed);
const int add_col = fastmodulo(col, add_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols;
const int mul_col = fastmodulo(col, mul_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col];
} else {
dst[col] = scale * x[col];
@ -354,8 +357,8 @@ static void rms_norm_f32_cuda(
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
@ -376,17 +379,17 @@ static void rms_norm_mul_f32_cuda(const float * x,
const int64_t mul_stride_row,
const int64_t mul_stride_channel,
const int64_t mul_stride_sample,
const int mul_ncols,
const int mul_nrows,
const int mul_nchannels,
const int mul_nsamples,
const uint32_t mul_ncols,
const uint32_t mul_nrows,
const uint32_t mul_nchannels,
const uint32_t mul_nsamples,
const int64_t add_stride_row,
const int64_t add_stride_channel,
const int64_t add_stride_sample,
const int add_ncols,
const int add_nrows,
const int add_nchannels,
const int add_nsamples,
const uint32_t add_ncols,
const uint32_t add_nrows,
const uint32_t add_nchannels,
const uint32_t add_nsamples,
const float eps,
cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples);
@ -395,36 +398,45 @@ static void rms_norm_mul_f32_cuda(const float * x,
return;
}
if (add == nullptr) {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
}
} else {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1);
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
const dim3 block_dims(256, 1, 1);
rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
} else {
const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst,
ncols, stride_row, stride_channel, stride_sample, eps,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples,
add, add_stride_row, add_stride_channel, add_stride_sample,
add_ncols, add_nrows, add_nchannels, add_nsamples);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add_nchannels_packed, add_nsamples_packed);
}
}
}

View File

@ -1,36 +1,50 @@
#include "pad.cuh"
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) {
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03
// blockIdx.y: idx of ne1
// blockIdx.x: idx of ne0 / BLOCK_SIZE
int nidx = threadIdx.x + blockIdx.x * blockDim.x;
if (nidx >= ne0) {
static __global__ void pad_f32(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3) {
// blockIdx.z: i3*ne2+i2
// blockIdx.y: i1
// blockIdx.x: i0 / CUDA_PAD_BLOCK_SIZE
// gridDim.y: ne1
int i0 = threadIdx.x + blockIdx.x * blockDim.x;
int i1 = blockIdx.y;
int i2 = blockIdx.z % ne2;
int i3 = blockIdx.z / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return;
}
// operation
int offset_dst =
nidx +
blockIdx.y * ne0 +
blockIdx.z * ne0 * gridDim.y;
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) {
int offset_src =
nidx +
blockIdx.y * ne00 +
blockIdx.z * ne00 * ne01;
dst[offset_dst] = x[offset_src];
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) &&
(i1 >= lp1 && i1 < ne1 - rp1) &&
(i2 >= lp2 && i2 < ne2 - rp2) &&
(i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t i00 = i0 - lp0;
const int64_t i01 = i1 - lp1;
const int64_t i02 = i2 - lp2;
const int64_t i03 = i3 - lp3;
const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
dst[dst_idx] = src[src_idx];
} else {
dst[offset_dst] = 0.0f;
dst[dst_idx] = 0.0f;
}
}
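The rewritten kernel pads on both sides of every dimension rather than only appending past the end of src. A worked example of the index mapping, with hypothetical padding amounts:

// Hypothetical: a 4x3 source (ne00 = 4, ne01 = 3) padded with
// lp0 = 1, rp0 = 2, lp1 = 0, rp1 = 1  ->  dst is 7x4 (ne0 = 7, ne1 = 4).
// dst element (i0 = 3, i1 = 2): 1 <= 3 < 7-2 and 0 <= 2 < 4-1, so it lies in
// the copied region; i00 = 3-1 = 2, i01 = 2-0 = 2, src_idx = 2*4 + 2 = 10.
// dst element (i0 = 6, i1 = 0): i0 >= ne0 - rp0 = 5, so it is written as 0.0f.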
static void pad_f32_cuda(const float * x, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03,
static void pad_f32_cuda(const float * src, float * dst,
const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
}
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors
GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3],
lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
}

View File

@ -1,26 +1,27 @@
#include "quantize.cuh"
#include <cstdint>
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
static __global__ void quantize_q8_1(
const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) {
const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (i0 >= ne0) {
return;
}
const int64_t i3 = fastdiv(blockIdx.z, ne2);
const int64_t i2 = blockIdx.z - i3*ne2.z;
const int64_t i1 = blockIdx.y;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t & i00 = i0;
const int64_t & i01 = i1;
const int64_t & i02 = i2;
const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
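The fastdiv-based decomposition of blockIdx.z is easy to check by hand. With assumed extents ne2 = 3 and ne3 = 4 (so gridDim.z = 12):

// blockIdx.z = 7:
//   i3 = fastdiv(7, ne2)  = 7 / 3 = 2   (sample index)
//   i2 = 7 - i3 * ne2.z   = 7 - 6 = 1   (channel index)
// i_cont then rebuilds the contiguous offset (((2*3 + 1)*ne1 + i1)*ne0 + i0),
// matching what the previous % and / arithmetic produced.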
block_q8_1 * y = (block_q8_1 *) vy;
@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
float amax = fabsf(xi);
float sum = xi;
amax = warp_reduce_max(amax);
sum = warp_reduce_sum(sum);
amax = warp_reduce_max<QK8_1>(amax);
sum = warp_reduce_sum<QK8_1>(sum);
const float d = amax / 127;
const float d = amax / 127.0f;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q;
@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
return;
}
reinterpret_cast<half&>(y[ib].ds.x) = d;
reinterpret_cast<half&>(y[ib].ds.y) = sum;
y[ib].ds = make_half2(d, sum);
}
template <mmq_q8_1_ds_layout ds_layout>
@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
GGML_ASSERT(!ids);
GGML_ASSERT(ne0 % QK8_1 == 0);
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
GGML_UNUSED(type_src0);
}

View File

@ -1,18 +1,19 @@
#include "scale.cuh"
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) {
const int i = blockDim.x*blockIdx.x + threadIdx.x;
#define MAX_GRIDDIM_X 0x7FFFFFFF
if (i >= k) {
return;
}
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
for (int64_t i = tid; i < nelements; i += stride) {
dst[i] = scale * x[i] + bias;
}
}
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k);
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
}
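The switch to a grid-stride loop keeps very large tensors covered even when the element count exceeds what a one-thread-per-element launch could address. A rough numeric check, assuming CUDA_SCALE_BLOCK_SIZE = 256 for illustration:

//   nelements = 2^33: num_blocks = 2^33 / 256 = 2^25, which fits the grid,
//   so each thread handles exactly one element and the loop body runs once.
//   nelements = 2^40: num_blocks = 2^32 exceeds MAX_GRIDDIM_X = 0x7FFFFFFF,
//   so the launch is clamped and each thread strides through the tensor in
//   steps of blockDim.x * gridDim.x until i >= nelements.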
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -24,7 +24,7 @@ TYPES_MMQ = [
"GGML_TYPE_Q4_0", "GGML_TYPE_Q4_1", "GGML_TYPE_Q5_0", "GGML_TYPE_Q5_1", "GGML_TYPE_Q8_0",
"GGML_TYPE_Q2_K", "GGML_TYPE_Q3_K", "GGML_TYPE_Q4_K", "GGML_TYPE_Q5_K", "GGML_TYPE_Q6_K",
"GGML_TYPE_IQ2_XXS", "GGML_TYPE_IQ2_XS", "GGML_TYPE_IQ2_S", "GGML_TYPE_IQ3_XXS", "GGML_TYPE_IQ3_S",
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS"
"GGML_TYPE_IQ1_S", "GGML_TYPE_IQ4_NL", "GGML_TYPE_IQ4_XS", "GGML_TYPE_MXFP4"
]
SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
@ -34,6 +34,13 @@ SOURCE_MMQ = """// This file has been autogenerated by generate_cu_files.py, do
DECL_MMQ_CASE({type});
"""
SOURCE_MMF = """// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE({type});
"""
def get_short_name(long_quant_name):
return long_quant_name.replace("GGML_TYPE_", "").lower()
@ -76,3 +83,7 @@ for ncols in [8, 16, 32, 64]:
for type in TYPES_MMQ:
with open(f"mmq-instance-{get_short_name(type)}.cu", "w") as f:
f.write(SOURCE_MMQ.format(type=type))
for type in range(1, 17):
with open(f"mmf-instance-ncols_{type}.cu", "w") as f:
f.write(SOURCE_MMF.format(type=type))

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(1);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(10);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(11);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(12);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(13);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(14);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(15);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(16);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(2);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(3);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(4);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(5);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(6);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(7);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(8);

View File

@ -0,0 +1,5 @@
// This file has been autogenerated by generate_cu_files.py, do not edit manually.
#include "../mmf.cuh"
DECL_MMF_CASE(9);

View File

@ -20,8 +20,8 @@
#define N_R0_Q5_1 4
#define N_SG_Q5_1 2
#define N_R0_Q8_0 4
#define N_SG_Q8_0 2
#define N_R0_Q8_0 2
#define N_SG_Q8_0 4
#define N_R0_MXFP4 2
#define N_SG_MXFP4 2
@ -68,6 +68,11 @@
#define N_R0_IQ4_XS 2
#define N_SG_IQ4_XS 2
// function constants offsets
#define FC_FLASH_ATTN_EXT 100
#define FC_FLASH_ATTN_EXT_VEC 200
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
// kernel argument structs
//
// - element counters (e.g. ne00) typically use int32_t to reduce register usage
@ -236,9 +241,11 @@ typedef struct {
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
@ -258,10 +265,43 @@ typedef struct {
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext;
typedef struct {
int32_t ne01;
int32_t ne02;
int32_t ne03;
uint64_t nb01;
uint64_t nb02;
uint64_t nb03;
int32_t ne11;
int32_t ne_12_2; // assume K and V are same shape
int32_t ne_12_3;
int32_t ns10;
uint64_t nb11;
uint64_t nb12;
uint64_t nb13;
int32_t ns20;
uint64_t nb21;
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1;
int32_t ne2;
int32_t ne3;
float scale;
float max_bias;
float m0;
float m1;
int32_t n_head_log2;
float logit_softcap;
} ggml_metal_kargs_flash_attn_ext_vec;
typedef struct {
int32_t nrows;
int32_t ne20;
} ggml_metal_kargs_flash_attn_ext_reduce;
} ggml_metal_kargs_flash_attn_ext_vec_reduce;
typedef struct {
int32_t ne00;

File diff suppressed because it is too large

Some files were not shown because too many files have changed in this diff