Merge branch 'master' into imatrix

Ed Addario 2025-09-06 13:07:56 +01:00
commit 7448bdb393
No known key found for this signature in database
GPG Key ID: E7875815A3230993
139 changed files with 5970 additions and 1777 deletions

View File

@ -22,7 +22,7 @@ AllowShortIfStatementsOnASingleLine: Never
AllowShortLambdasOnASingleLine: Inline AllowShortLambdasOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false AllowShortLoopsOnASingleLine: false
AlwaysBreakBeforeMultilineStrings: true AlwaysBreakBeforeMultilineStrings: true
BinPackArguments: false BinPackArguments: true
BinPackParameters: false # OnePerLine BinPackParameters: false # OnePerLine
BitFieldColonSpacing: Both BitFieldColonSpacing: Both
BreakBeforeBraces: Custom # Attach BreakBeforeBraces: Custom # Attach

View File

@ -17,7 +17,7 @@ jobs:
steps: steps:
- uses: actions/stale@v5 - uses: actions/stale@v5
with: with:
exempt-issue-labels: "refactoring,help wanted,good first issue,research,bug,roadmap" exempt-issue-labels: "refactoring,help wanted,good first issue,research 🔬,bug,roadmap"
days-before-issue-stale: 30 days-before-issue-stale: 30
days-before-issue-close: 14 days-before-issue-close: 14
stale-issue-label: "stale" stale-issue-label: "stale"

View File

@ -137,6 +137,7 @@ Instructions for adding support for new models: [HOWTO-add-model.md](docs/develo
- [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview) - [X] [Trillion-7B-preview](https://huggingface.co/trillionlabs/Trillion-7B-preview)
- [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32) - [x] [Ling models](https://huggingface.co/collections/inclusionAI/ling-67c51c85b34a7ea0aba94c32)
- [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38) - [x] [LFM2 models](https://huggingface.co/collections/LiquidAI/lfm2-686d721927015b2ad73eaa38)
- [x] [Hunyuan models](https://huggingface.co/collections/tencent/hunyuan-dense-model-6890632cda26b19119c9c5e7)
#### Multimodal #### Multimodal

View File

@ -386,10 +386,10 @@ function gg_run_open_llama_7b_v2 {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl { function check_ppl {
qnt="$1" qnt="$1"
@ -520,8 +520,8 @@ function gg_run_pythia_1_4b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test_60} -ngl 99 -c 128 -b 128 --chunks 1 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl { function check_ppl {
qnt="$1" qnt="$1"
@ -651,10 +651,10 @@ function gg_run_pythia_2_8b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -t 1 -ngl 99 -c 2048 -b 512 --chunks 4 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 0 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl { function check_ppl {
qnt="$1" qnt="$1"

View File

@ -1263,6 +1263,18 @@ static std::string list_builtin_chat_templates() {
return msg.str(); return msg.str();
} }
static bool is_truthy(const std::string & value) {
return value == "on" || value == "enabled" || value == "1";
}
static bool is_falsey(const std::string & value) {
return value == "off" || value == "disabled" || value == "0";
}
static bool is_autoy(const std::string & value) {
return value == "auto" || value == "-1";
}
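For reference, a minimal self-contained sketch (not part of the patch) restating the value spellings these helpers accept; the helper bodies below simply mirror the ones added above, so it can be compiled and run on its own.

```cpp
// Standalone sketch of the accepted spellings for the new tri-state CLI values
// (used below by -fa/--flash-attn and --log-colors); not part of the patch.
#include <cassert>
#include <string>

static bool is_truthy(const std::string & v) { return v == "on"  || v == "enabled"  || v == "1"; }
static bool is_falsey(const std::string & v) { return v == "off" || v == "disabled" || v == "0"; }
static bool is_autoy (const std::string & v) { return v == "auto" || v == "-1"; }

int main() {
    assert(is_truthy("on")   && is_truthy("enabled")  && is_truthy("1"));
    assert(is_falsey("off")  && is_falsey("disabled") && is_falsey("0"));
    assert(is_autoy("auto")  && is_autoy("-1"));
    assert(!is_truthy("auto") && !is_falsey("on") && !is_autoy("off"));
    return 0;
}
```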
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) { common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// load dynamic backends // load dynamic backends
ggml_backend_load_all(); ggml_backend_load_all();
@ -1544,13 +1556,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.n_chunks = value; params.n_chunks = value;
} }
).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL})); ).set_examples({LLAMA_EXAMPLE_IMATRIX, LLAMA_EXAMPLE_PERPLEXITY, LLAMA_EXAMPLE_RETRIEVAL}));
add_opt(common_arg( add_opt(common_arg({ "-fa", "--flash-attn" }, "[on|off|auto]",
{"-fa", "--flash-attn"}, string_format("set Flash Attention use ('on', 'off', or 'auto', default: '%s')",
string_format("enable Flash Attention (default: %s)", params.flash_attn ? "enabled" : "disabled"), llama_flash_attn_type_name(params.flash_attn_type)),
[](common_params & params) { [](common_params & params, const std::string & value) {
params.flash_attn = true; if (is_truthy(value)) {
} params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_ENABLED;
).set_env("LLAMA_ARG_FLASH_ATTN")); } else if (is_falsey(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_DISABLED;
} else if (is_autoy(value)) {
params.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO;
} else {
throw std::runtime_error(
string_format("error: unkown value for --flash-attn: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_ARG_FLASH_ATTN"));
add_opt(common_arg( add_opt(common_arg(
{"-p", "--prompt"}, "PROMPT", {"-p", "--prompt"}, "PROMPT",
"prompt to start generation with; for system message, use -sys", "prompt to start generation with; for system message, use -sys",
@ -2458,7 +2478,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT")); ).set_examples({LLAMA_EXAMPLE_SPECULATIVE, LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_N_CPU_MOE_DRAFT"));
add_opt(common_arg( add_opt(common_arg(
{"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N", {"-ngl", "--gpu-layers", "--n-gpu-layers"}, "N",
"number of layers to store in VRAM", string_format("max. number of layers to store in VRAM (default: %d)", params.n_gpu_layers),
[](common_params & params, int value) { [](common_params & params, int value) {
params.n_gpu_layers = value; params.n_gpu_layers = value;
if (!llama_supports_gpu_offload()) { if (!llama_supports_gpu_offload()) {
@ -2954,13 +2974,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.endpoint_metrics = true; params.endpoint_metrics = true;
} }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS")); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_METRICS"));
add_opt(common_arg(
{"--slots"},
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[](common_params & params) {
params.endpoint_slots = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg( add_opt(common_arg(
{"--props"}, {"--props"},
string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"), string_format("enable changing global properties via POST /props (default: %s)", params.endpoint_props ? "enabled" : "disabled"),
@ -2968,6 +2981,13 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.endpoint_props = true; params.endpoint_props = true;
} }
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS")); ).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_PROPS"));
add_opt(common_arg(
{"--slots"},
string_format("enable slots monitoring endpoint (default: %s)", params.endpoint_slots ? "enabled" : "disabled"),
[](common_params & params) {
params.endpoint_slots = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER}).set_env("LLAMA_ARG_ENDPOINT_SLOTS"));
add_opt(common_arg( add_opt(common_arg(
{"--no-slots"}, {"--no-slots"},
"disables slots monitoring endpoint", "disables slots monitoring endpoint",
@ -3126,13 +3146,21 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
common_log_set_file(common_log_main(), value.c_str()); common_log_set_file(common_log_main(), value.c_str());
} }
)); ));
add_opt(common_arg( add_opt(common_arg({ "--log-colors" }, "[on|off|auto]",
{"--log-colors"}, "Set colored logging ('on', 'off', or 'auto', default: 'auto')\n"
"Enable colored logging", "'auto' enables colors when output is to a terminal",
[](common_params &) { [](common_params &, const std::string & value) {
common_log_set_colors(common_log_main(), true); if (is_truthy(value)) {
} common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);
).set_env("LLAMA_LOG_COLORS")); } else if (is_falsey(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED);
} else if (is_autoy(value)) {
common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);
} else {
throw std::invalid_argument(
string_format("error: unkown value for --log-colors: '%s'\n", value.c_str()));
}
}).set_env("LLAMA_LOG_COLORS"));
add_opt(common_arg( add_opt(common_arg(
{"-v", "--verbose", "--log-verbose"}, {"-v", "--verbose", "--log-verbose"},
"Set verbosity level to infinity (i.e. log all messages, useful for debugging)", "Set verbosity level to infinity (i.e. log all messages, useful for debugging)",
@ -3459,8 +3487,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF"; params.model.hf_repo = "ggml-org/Qwen2.5-Coder-1.5B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf"; params.model.hf_file = "qwen2.5-coder-1.5b-q8_0.gguf";
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;
@ -3475,8 +3501,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF"; params.model.hf_repo = "ggml-org/Qwen2.5-Coder-3B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf"; params.model.hf_file = "qwen2.5-coder-3b-q8_0.gguf";
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;
@ -3491,8 +3515,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF"; params.model.hf_repo = "ggml-org/Qwen2.5-Coder-7B-Q8_0-GGUF";
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;
@ -3508,10 +3530,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf"; params.model.hf_file = "qwen2.5-coder-7b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;
@ -3527,10 +3546,7 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf"; params.model.hf_file = "qwen2.5-coder-14b-q8_0.gguf";
params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF"; params.speculative.model.hf_repo = "ggml-org/Qwen2.5-Coder-0.5B-Q8_0-GGUF";
params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf"; params.speculative.model.hf_file = "qwen2.5-coder-0.5b-q8_0.gguf";
params.speculative.n_gpu_layers = 99;
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;
@ -3545,8 +3561,6 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF"; params.model.hf_repo = "ggml-org/Qwen3-Coder-30B-A3B-Instruct-Q8_0-GGUF";
params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf"; params.model.hf_file = "qwen3-coder-30b-a3b-instruct-q8_0.gguf";
params.port = 8012; params.port = 8012;
params.n_gpu_layers = 99;
params.flash_attn = true;
params.n_ubatch = 1024; params.n_ubatch = 1024;
params.n_batch = 1024; params.n_batch = 1024;
params.n_ctx = 0; params.n_ctx = 0;

View File

@ -163,6 +163,19 @@ common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::strin
throw std::runtime_error("Invalid tool_choice: " + tool_choice); throw std::runtime_error("Invalid tool_choice: " + tool_choice);
} }
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates) {
common_chat_templates_inputs dummy_inputs;
common_chat_msg msg;
msg.role = "user";
msg.content = "test";
dummy_inputs.messages = {msg};
dummy_inputs.enable_thinking = false;
const auto rendered_no_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
dummy_inputs.enable_thinking = true;
const auto rendered_with_thinking = common_chat_templates_apply(chat_templates, dummy_inputs);
return rendered_no_thinking.prompt != rendered_with_thinking.prompt;
}
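A hedged usage sketch of the new helper: callers such as the server can probe whether the active template actually reacts to `enable_thinking` before exposing a thinking toggle. The `tmpls` pointer is assumed to come from the usual `common_chat_templates_init` path; only `common_chat_templates_support_enable_thinking` is taken from this diff.

```cpp
// Sketch only: report whether the loaded chat template distinguishes enable_thinking.
#include "chat.h"
#include "log.h"

void report_thinking_support(const common_chat_templates * tmpls) {
    if (common_chat_templates_support_enable_thinking(tmpls)) {
        LOG_INF("chat template reacts to enable_thinking; expose the thinking toggle\n");
    } else {
        LOG_INF("chat template ignores enable_thinking; hide the thinking toggle\n");
    }
}
```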
template <> template <>
std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) { std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const json & messages) {
std::vector<common_chat_msg> msgs; std::vector<common_chat_msg> msgs;
@ -623,6 +636,7 @@ const char * common_chat_format_name(common_chat_format format) {
case COMMON_CHAT_FORMAT_GRANITE: return "Granite"; case COMMON_CHAT_FORMAT_GRANITE: return "Granite";
case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS"; case COMMON_CHAT_FORMAT_GPT_OSS: return "GPT-OSS";
case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS"; case COMMON_CHAT_FORMAT_SEED_OSS: return "Seed-OSS";
case COMMON_CHAT_FORMAT_NEMOTRON_V2: return "Nemotron V2";
default: default:
throw std::runtime_error("Unknown chat format"); throw std::runtime_error("Unknown chat format");
} }
@ -1184,6 +1198,67 @@ static common_chat_params common_chat_params_init_llama_3_x(const common_chat_te
}); });
return data; return data;
} }
static common_chat_params common_chat_params_init_nemotron_v2(const common_chat_template & tmpl, const struct templates_params & inputs) {
common_chat_params data;
// Generate the prompt using the apply() function with the template
data.prompt = apply(tmpl, inputs);
data.format = COMMON_CHAT_FORMAT_NEMOTRON_V2;
// Handle thinking tags appropriately based on inputs.enable_thinking
if (string_ends_with(data.prompt, "<think>\n")) {
if (!inputs.enable_thinking) {
data.prompt += "</think>";
} else {
data.thinking_forced_open = true;
}
}
// When tools are present, build grammar for the <TOOLCALL> format, similar to CommandR, but without tool call ID
if (!inputs.tools.is_null() && inputs.tools.is_array() && !inputs.tools.empty()) {
data.grammar_lazy = true;
data.grammar = build_grammar([&](const common_grammar_builder & builder) {
auto schemas = json::array();
foreach_function(inputs.tools, [&](const json & tool) {
const auto & function = tool.at("function");
schemas.push_back({
{ "type", "object" },
{ "properties",
{
{ "name",
{
{ "type", "string" },
{ "const", function.at("name") },
} },
{ "arguments", function.at("parameters") },
} },
{ "required", json::array({ "name", "arguments" }) },
});
});
auto schema = json{
{ "type", "array" },
{ "items", schemas.size() == 1 ? schemas[0] : json{ { "anyOf", schemas } } },
{ "minItems", 1 },
};
if (!inputs.parallel_tool_calls) {
schema["maxItems"] = 1;
}
builder.add_rule("root",
std::string(data.thinking_forced_open ? "( \"</think>\" space )? " : "") +
"\"<TOOLCALL>\" " + builder.add_schema("tool_calls", schema) +
" \"</TOOLCALL>\"");
});
data.grammar_triggers.push_back({ COMMON_GRAMMAR_TRIGGER_TYPE_PATTERN_FULL,
// If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ?
"[\\s\\S]*?(</think>\\s*)" :
"(?:<think>[\\s\\S]*?</think>\\s*)?") +
"(<TOOLCALL>)[\\s\\S]*" });
}
return data;
}
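For orientation, an illustrative example of the assistant output this grammar constrains: a JSON array of `{name, arguments}` objects wrapped in `<TOOLCALL>`/`</TOOLCALL>`, with no tool-call ID. The tool name and arguments below are hypothetical.

```cpp
// Illustrative only: the wire format accepted by the Nemotron V2 tool-call grammar above.
// "get_weather" and its arguments are hypothetical; any declared tool name/schema applies.
static const char * k_nemotron_v2_tool_call_example =
    R"(<TOOLCALL>[{"name": "get_weather", "arguments": {"city": "Paris"}}]</TOOLCALL>)";
```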
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) { static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
if (!builder.syntax().parse_tool_calls) { if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest()); builder.add_content(builder.consume_rest());
@ -1830,7 +1905,7 @@ static common_chat_params common_chat_params_init_hermes_2_pro(const common_chat
// If thinking_forced_open, then we capture the </think> tag in the grammar, // If thinking_forced_open, then we capture the </think> tag in the grammar,
// (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar) // (important for required tool choice) and in the trigger's first capture (decides what is sent to the grammar)
std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + ( std::string(data.thinking_forced_open ? "[\\s\\S]*?(</think>\\s*)" : "(?:<think>[\\s\\S]*?</think>\\s*)?") + (
"(\\s*" "\\s*("
"(?:<tool_call>" "(?:<tool_call>"
"|<function" "|<function"
"|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?" "|(?:```(?:json|xml)?\n\\s*)?(?:<function_call>|<tools>|<xml><json>|<response>)?"
@ -2060,6 +2135,33 @@ static void common_chat_parse_granite(common_chat_msg_parser & builder) {
} }
} }
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) { static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
// Parse thinking tags first - this handles the main reasoning content // Parse thinking tags first - this handles the main reasoning content
builder.try_parse_reasoning("<seed:think>", "</seed:think>"); builder.try_parse_reasoning("<seed:think>", "</seed:think>");
@ -2293,6 +2395,11 @@ static common_chat_params common_chat_templates_apply_jinja(
return common_chat_params_init_seed_oss(tmpl, params, inputs); return common_chat_params_init_seed_oss(tmpl, params, inputs);
} }
// Nemotron v2
if (src.find("<SPECIAL_10>") != std::string::npos) {
return common_chat_params_init_nemotron_v2(tmpl, params);
}
// Use generic handler when mixing tools + JSON schema. // Use generic handler when mixing tools + JSON schema.
// TODO: support that mix in handlers below. // TODO: support that mix in handlers below.
if ((params.tools.is_array() && params.json_schema.is_object())) { if ((params.tools.is_array() && params.json_schema.is_object())) {
@ -2454,6 +2561,9 @@ static void common_chat_parse(common_chat_msg_parser & builder) {
case COMMON_CHAT_FORMAT_SEED_OSS: case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder); common_chat_parse_seed_oss(builder);
break; break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
default: default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format)); throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
} }

View File

@ -112,6 +112,7 @@ enum common_chat_format {
COMMON_CHAT_FORMAT_GRANITE, COMMON_CHAT_FORMAT_GRANITE,
COMMON_CHAT_FORMAT_GPT_OSS, COMMON_CHAT_FORMAT_GPT_OSS,
COMMON_CHAT_FORMAT_SEED_OSS, COMMON_CHAT_FORMAT_SEED_OSS,
COMMON_CHAT_FORMAT_NEMOTRON_V2,
COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
}; };
@ -198,6 +199,8 @@ common_chat_msg common_chat_parse(const std::string & input, bool is_p
common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
// Parses a JSON array of messages in OpenAI's chat completion API format. // Parses a JSON array of messages in OpenAI's chat completion API format.
// T can be std::string containing JSON or nlohmann::ordered_json // T can be std::string containing JSON or nlohmann::ordered_json
template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages); template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);

View File

@ -901,7 +901,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams); llama_model * model = llama_model_load_from_file(params.model.path.c_str(), mparams);
if (model == NULL) { if (model == NULL) {
LOG_ERR("%s: failed to load model '%s'\n", __func__, params.model.path.c_str()); LOG_ERR("%s: failed to load model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
return iparams; return iparams;
} }
@ -911,7 +912,8 @@ struct common_init_result common_init_from_params(common_params & params) {
llama_context * lctx = llama_init_from_model(model, cparams); llama_context * lctx = llama_init_from_model(model, cparams);
if (lctx == NULL) { if (lctx == NULL) {
LOG_ERR("%s: failed to create context with model '%s'\n", __func__, params.model.path.c_str()); LOG_ERR("%s: failed to create context with model '%s', try reducing --n-gpu-layers if you're running out of VRAM\n",
__func__, params.model.path.c_str());
llama_model_free(model); llama_model_free(model);
return iparams; return iparams;
} }
@ -1157,10 +1159,10 @@ struct llama_context_params common_context_params_to_llama(const common_params &
cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.yarn_orig_ctx = params.yarn_orig_ctx;
cparams.pooling_type = params.pooling_type; cparams.pooling_type = params.pooling_type;
cparams.attention_type = params.attention_type; cparams.attention_type = params.attention_type;
cparams.flash_attn_type = params.flash_attn_type;
cparams.cb_eval = params.cb_eval; cparams.cb_eval = params.cb_eval;
cparams.cb_eval_user_data = params.cb_eval_user_data; cparams.cb_eval_user_data = params.cb_eval_user_data;
cparams.offload_kqv = !params.no_kv_offload; cparams.offload_kqv = !params.no_kv_offload;
cparams.flash_attn = params.flash_attn;
cparams.no_perf = params.no_perf; cparams.no_perf = params.no_perf;
cparams.op_offload = !params.no_op_offload; cparams.op_offload = !params.no_op_offload;
cparams.swa_full = params.swa_full; cparams.swa_full = params.swa_full;
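A minimal sketch of how the new tri-state setting reaches the library: `flash_attn_type` is copied into `llama_context_params` (as in the hunk above) instead of the removed boolean `flash_attn`. The enum and field names are taken from this diff; the surrounding calls are the standard llama.h context-creation API.

```cpp
// Sketch: create a context with flash attention left to the backend's auto detection.
#include "llama.h"

struct llama_context * make_ctx_with_auto_fa(struct llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // 'on'/'off' map to _ENABLED/_DISABLED
    return llama_init_from_model(model, cparams);
}
```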

View File

@ -312,6 +312,7 @@ struct common_params {
enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type for embeddings
enum llama_flash_attn_type flash_attn_type = LLAMA_FLASH_ATTN_TYPE_AUTO; // whether to use Flash Attention
struct common_params_sampling sampling; struct common_params_sampling sampling;
struct common_params_speculative speculative; struct common_params_speculative speculative;
@ -375,7 +376,6 @@ struct common_params {
bool multiline_input = false; // reverse the usage of `\` bool multiline_input = false; // reverse the usage of `\`
bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool simple_io = false; // improves compatibility with subprocesses and limited consoles
bool cont_batching = true; // insert new sequences for decoding on-the-fly bool cont_batching = true; // insert new sequences for decoding on-the-fly
bool flash_attn = false; // flash attention
bool no_perf = false; // disable performance metrics bool no_perf = false; // disable performance metrics
bool ctx_shift = false; // context shift on infinite text generation bool ctx_shift = false; // context shift on infinite text generation
bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)
@ -444,7 +444,7 @@ struct common_params {
// "advanced" endpoints are disabled by default for better security // "advanced" endpoints are disabled by default for better security
bool webui = true; bool webui = true;
bool endpoint_slots = false; bool endpoint_slots = true;
bool endpoint_props = false; // only control POST requests, not GET bool endpoint_props = false; // only control POST requests, not GET
bool endpoint_metrics = false; bool endpoint_metrics = false;

View File

@ -4,17 +4,52 @@
#include <condition_variable> #include <condition_variable>
#include <cstdarg> #include <cstdarg>
#include <cstdio> #include <cstdio>
#include <cstdlib>
#include <cstring>
#include <mutex> #include <mutex>
#include <sstream> #include <sstream>
#include <thread> #include <thread>
#include <vector> #include <vector>
#if defined(_WIN32)
# include <io.h>
# include <windows.h>
# define isatty _isatty
# define fileno _fileno
#else
# include <unistd.h>
#endif // defined(_WIN32)
int common_log_verbosity_thold = LOG_DEFAULT_LLAMA; int common_log_verbosity_thold = LOG_DEFAULT_LLAMA;
void common_log_set_verbosity_thold(int verbosity) { void common_log_set_verbosity_thold(int verbosity) {
common_log_verbosity_thold = verbosity; common_log_verbosity_thold = verbosity;
} }
// Auto-detect if colors should be enabled based on terminal and environment
static bool common_log_should_use_colors_auto() {
// Check NO_COLOR environment variable (https://no-color.org/)
if (const char * no_color = std::getenv("NO_COLOR")) {
if (no_color[0] != '\0') {
return false;
}
}
// Check TERM environment variable
if (const char * term = std::getenv("TERM")) {
if (std::strcmp(term, "dumb") == 0) {
return false;
}
}
// Check if stdout and stderr are connected to a terminal
// We check both because log messages can go to either
bool stdout_is_tty = isatty(fileno(stdout));
bool stderr_is_tty = isatty(fileno(stderr));
return stdout_is_tty || stderr_is_tty;
}
static int64_t t_us() { static int64_t t_us() {
return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count(); return std::chrono::duration_cast<std::chrono::microseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
} }
@ -353,6 +388,11 @@ struct common_log * common_log_init() {
struct common_log * common_log_main() { struct common_log * common_log_main() {
static struct common_log log; static struct common_log log;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
// Set default to auto-detect colors
log.set_colors(common_log_should_use_colors_auto());
});
return &log; return &log;
} }
@ -380,8 +420,19 @@ void common_log_set_file(struct common_log * log, const char * file) {
log->set_file(file); log->set_file(file);
} }
void common_log_set_colors(struct common_log * log, bool colors) { void common_log_set_colors(struct common_log * log, log_colors colors) {
log->set_colors(colors); if (colors == LOG_COLORS_AUTO) {
log->set_colors(common_log_should_use_colors_auto());
return;
}
if (colors == LOG_COLORS_DISABLED) {
log->set_colors(false);
return;
}
GGML_ASSERT(colors == LOG_COLORS_ENABLED);
log->set_colors(true);
} }
void common_log_set_prefix(struct common_log * log, bool prefix) { void common_log_set_prefix(struct common_log * log, bool prefix) {

View File

@ -24,6 +24,12 @@
#define LOG_DEFAULT_DEBUG 1 #define LOG_DEFAULT_DEBUG 1
#define LOG_DEFAULT_LLAMA 0 #define LOG_DEFAULT_LLAMA 0
enum log_colors {
LOG_COLORS_AUTO = -1,
LOG_COLORS_DISABLED = 0,
LOG_COLORS_ENABLED = 1,
};
// needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower // needed by the LOG_TMPL macro to avoid computing log arguments if the verbosity lower
// set via common_log_set_verbosity() // set via common_log_set_verbosity()
extern int common_log_verbosity_thold; extern int common_log_verbosity_thold;
@ -65,10 +71,10 @@ void common_log_add(struct common_log * log, enum ggml_log_level level, const ch
// D - debug (stderr, V = LOG_DEFAULT_DEBUG) // D - debug (stderr, V = LOG_DEFAULT_DEBUG)
// //
void common_log_set_file (struct common_log * log, const char * file); // not thread-safe void common_log_set_file (struct common_log * log, const char * file); // not thread-safe
void common_log_set_colors (struct common_log * log, bool colors); // not thread-safe void common_log_set_colors (struct common_log * log, log_colors colors); // not thread-safe
void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log void common_log_set_prefix (struct common_log * log, bool prefix); // whether to output prefix to each log
void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix void common_log_set_timestamps(struct common_log * log, bool timestamps); // whether to output timestamps in the prefix
// helper macros for logging // helper macros for logging
// use these to avoid computing log arguments if the verbosity of the log is higher than the threshold // use these to avoid computing log arguments if the verbosity of the log is higher than the threshold
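A hedged usage sketch of the reworked colors API from common/log.h: `LOG_COLORS_AUTO` defers to the detection logic added above (NO_COLOR, TERM=dumb, isatty), while the other two values force the behavior.

```cpp
// Sketch: the three color modes exposed by common_log_set_colors().
#include "log.h"

int main() {
    common_log_set_colors(common_log_main(), LOG_COLORS_AUTO);     // respect NO_COLOR/TERM/isatty
    LOG_INF("colors follow terminal detection\n");

    common_log_set_colors(common_log_main(), LOG_COLORS_DISABLED); // e.g. when piping to a file
    LOG_INF("plain output\n");

    common_log_set_colors(common_log_main(), LOG_COLORS_ENABLED);  // force ANSI colors
    LOG_INF("colored output\n");
    return 0;
}
```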

View File

@ -426,8 +426,29 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl) {
// helpers // helpers
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl) { llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort) {
return &gsmpl->cur_p; auto * res = &gsmpl->cur_p;
if (do_sort && !res->sorted) {
// remember the selected token before sorting
const llama_token id = res->data[res->selected].id;
std::sort(res->data, res->data + res->size, [](const llama_token_data & a, const llama_token_data & b) {
return a.p > b.p;
});
// restore the selected token after sorting
for (size_t i = 0; i < res->size; ++i) {
if (res->data[i].id == id) {
res->selected = i;
break;
}
}
res->sorted = true;
}
return res;
} }
llama_token common_sampler_last(const struct common_sampler * gsmpl) { llama_token common_sampler_last(const struct common_sampler * gsmpl) {

View File

@ -86,7 +86,9 @@ uint32_t common_sampler_get_seed(const struct common_sampler * gsmpl);
// helpers // helpers
// access the internal list of current candidate tokens // access the internal list of current candidate tokens
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl); // if do_sort == true, the candidates are guaranteed to be sorted afterwards (in descending order of probability)
// the .sorted flag of the result indicates whether the returned candidates are sorted
llama_token_data_array * common_sampler_get_candidates(struct common_sampler * gsmpl, bool do_sort);
// get the last accepted token // get the last accepted token
llama_token common_sampler_last(const struct common_sampler * gsmpl); llama_token common_sampler_last(const struct common_sampler * gsmpl);
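A hedged sketch of the new `do_sort` parameter: ask for a sorted view only when the top-k ordering is actually needed, since the sort is performed lazily and cached via the `.sorted` flag. `smpl` and `ctx` are assumed to be initialized elsewhere; the calls are the common/sampling.h API shown here.

```cpp
// Sketch: sample once, then inspect the top candidates in descending probability order.
#include <algorithm>
#include "log.h"
#include "sampling.h"

void log_top_candidates(struct common_sampler * smpl, struct llama_context * ctx) {
    common_sampler_sample(smpl, ctx, /*idx=*/0, /*grammar_first=*/true);

    const auto * cur_p = common_sampler_get_candidates(smpl, /*do_sort=*/true);
    for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
        LOG_DBG("candidate %d: token %6d, p = %.3f\n", k, cur_p->data[k].id, cur_p->data[k].p);
    }
}
```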

View File

@ -317,7 +317,7 @@ llama_tokens common_speculative_gen_draft(
common_sampler_sample(smpl, ctx_dft, 0, true); common_sampler_sample(smpl, ctx_dft, 0, true);
const auto * cur_p = common_sampler_get_candidates(smpl); const auto * cur_p = common_sampler_get_candidates(smpl, true);
for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) { for (int k = 0; k < std::min(3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n", LOG_DBG(" - draft candidate %3d, pos %3d: %6d (%8.3f) '%s'\n",

View File

@ -302,10 +302,6 @@ class ModelBase:
# data = data_torch.squeeze().numpy() # data = data_torch.squeeze().numpy()
data = data_torch.numpy() data = data_torch.numpy()
# if data ends up empty, it means data_torch was a scalar tensor -> restore
if len(data.shape) == 0:
data = data_torch.numpy()
n_dims = len(data.shape) n_dims = len(data.shape)
data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims) data_qtype: gguf.GGMLQuantizationType | bool = self.tensor_force_quant(name, new_name, bid, n_dims)
@ -5126,6 +5122,15 @@ class Gemma3Model(TextModel):
return [(self.map_tensor_name(name), data_torch)] return [(self.map_tensor_name(name), data_torch)]
@ModelBase.register("Gemma3TextModel")
class EmbeddingGemma(Gemma3Model):
model_arch = gguf.MODEL_ARCH.GEMMA_EMBEDDING
def set_gguf_parameters(self):
super().set_gguf_parameters()
self._try_set_pooling_type()
@ModelBase.register("Gemma3ForConditionalGeneration") @ModelBase.register("Gemma3ForConditionalGeneration")
class Gemma3VisionModel(MmprojModel): class Gemma3VisionModel(MmprojModel):
def set_gguf_parameters(self): def set_gguf_parameters(self):

View File

@ -12,7 +12,7 @@ import json
from math import prod from math import prod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
from transformers import AutoConfig from transformers import AutoConfig, AutoTokenizer
import torch import torch
@ -26,6 +26,8 @@ import gguf
# reuse model definitions from convert_hf_to_gguf.py # reuse model definitions from convert_hf_to_gguf.py
from convert_hf_to_gguf import LazyTorchTensor, ModelBase from convert_hf_to_gguf import LazyTorchTensor, ModelBase
from gguf.constants import GGUFValueType
logger = logging.getLogger("lora-to-gguf") logger = logging.getLogger("lora-to-gguf")
@ -369,7 +371,31 @@ if __name__ == '__main__':
self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora") self.gguf_writer.add_string(gguf.Keys.Adapter.TYPE, "lora")
def set_gguf_parameters(self): def set_gguf_parameters(self):
logger.debug("GGUF KV: %s = %d", gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha) self.gguf_writer.add_float32(gguf.Keys.Adapter.LORA_ALPHA, self.lora_alpha)
alora_invocation_tokens = lparams.get("alora_invocation_tokens")
invocation_string = lparams.get("invocation_string")
if invocation_string and not alora_invocation_tokens:
logger.debug("Tokenizing invocation_string -> alora_invocation_tokens")
base_model_path_or_id = hparams.get("_name_or_path")
try:
tokenizer = AutoTokenizer.from_pretrained(base_model_path_or_id)
except ValueError:
logger.error("Unable to load tokenizer from %s", base_model_path_or_id)
raise
# NOTE: There's an off-by-one with the older aLoRAs where
# the invocation string includes the "<|start_of_turn|>"
# token, but the adapters themselves were trained to
# activate _after_ that first token, so we drop it here.
alora_invocation_tokens = tokenizer(invocation_string)["input_ids"][1:]
if alora_invocation_tokens:
logger.debug("GGUF KV: %s = %s", gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS, alora_invocation_tokens)
self.gguf_writer.add_key_value(
gguf.Keys.Adapter.ALORA_INVOCATION_TOKENS,
alora_invocation_tokens,
GGUFValueType.ARRAY,
GGUFValueType.UINT32,
)
def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]: def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
# Never add extra tensors (e.g. rope_freqs) for LoRA adapters # Never add extra tensors (e.g. rope_freqs) for LoRA adapters

View File

@ -293,17 +293,14 @@ We would like to thank Tuo Dai, Shanni Li, and all of the project maintainers fr
## Environment variable setup ## Environment variable setup
### GGML_CANN_ASYNC_MODE
Enables asynchronous operator submission. Disabled by default.
### GGML_CANN_MEM_POOL ### GGML_CANN_MEM_POOL
Specifies the memory pool management strategy: Specifies the memory pool management strategy. Default is vmm.
- vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool. - vmm: Utilizes a virtual memory manager pool. If hardware support for VMM is unavailable, falls back to the legacy (leg) memory pool.
- prio: Employs a priority queue-based memory pool management. - prio: Employs a priority queue-based memory pool management.
- leg: Uses a fixed-size buffer pool. - leg: Uses a fixed-size buffer pool.
### GGML_CANN_DISABLE_BUF_POOL_CLEAN ### GGML_CANN_DISABLE_BUF_POOL_CLEAN
@ -312,5 +309,8 @@ Controls automatic cleanup of the memory pool. This option is only effective whe
### GGML_CANN_WEIGHT_NZ ### GGML_CANN_WEIGHT_NZ
Converting the matmul weight format from ND to NZ can significantly improve performance on the 310I DUO NPU. Converting the matmul weight format from ND to NZ improves performance. Enabled by default.
### GGML_CANN_ACL_GRAPH
Operators are executed using ACL graph execution, rather than in op-by-op (eager) mode. Enabled by default.

View File

@ -42,18 +42,6 @@ cmake --build build --config Release -j $(nproc)
cmake --build build --config Release -j $(nproc) cmake --build build --config Release -j $(nproc)
``` ```
- By default, NNPA is disabled by default. To enable it:
```bash
cmake -S . -B build \
-DCMAKE_BUILD_TYPE=Release \
-DGGML_BLAS=ON \
-DGGML_BLAS_VENDOR=OpenBLAS \
-DGGML_NNPA=ON
cmake --build build --config Release -j $(nproc)
```
- For debug builds: - For debug builds:
```bash ```bash
@ -164,15 +152,11 @@ All models need to be converted to Big-Endian. You can achieve this in three cas
Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation. Only available in IBM z15/LinuxONE 3 or later system with the `-DGGML_VXE=ON` (turned on by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z14/arch12. In such systems, the APIs can still run but will use a scalar implementation.
### 2. NNPA Vector Intrinsics Acceleration ### 2. zDNN Accelerator (WIP)
Only available in IBM z16/LinuxONE 4 or later system with the `-DGGML_NNPA=ON` (turned off by default) compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs can still run but will use a scalar implementation.
### 3. zDNN Accelerator (WIP)
Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines. Only available in IBM z17/LinuxONE 5 or later system with the `-DGGML_ZDNN=ON` compile flag. No hardware acceleration is possible with llama.cpp with older systems, such as IBM z15/arch13. In such systems, the APIs will default back to CPU routines.
### 4. Spyre Accelerator ### 3. Spyre Accelerator
_Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._ _Only available with IBM z17 / LinuxONE 5 or later system. No support currently available._
@ -230,10 +214,6 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
CXXFLAGS="-include cstdint" pip3 install -r requirements.txt CXXFLAGS="-include cstdint" pip3 install -r requirements.txt
``` ```
5. `-DGGML_NNPA=ON` generates gibberish output
Answer: We are aware of this as detailed in [this issue](https://github.com/ggml-org/llama.cpp/issues/14877). Please either try reducing the number of threads, or disable the compile option using `-DGGML_NNPA=OFF`.
## Getting Help on IBM Z & LinuxONE ## Getting Help on IBM Z & LinuxONE
1. **Bugs, Feature Requests** 1. **Bugs, Feature Requests**
@ -258,38 +238,38 @@ IBM VXE/VXE2 SIMD acceleration depends on the BLAS implementation. It is strongl
## Appendix B: SIMD Support Matrix ## Appendix B: SIMD Support Matrix
| | VX/VXE/VXE2 | NNPA | zDNN | Spyre | | | VX/VXE/VXE2 | zDNN | Spyre |
| ---------- | ----------- | ---- | ---- | ----- | |------------|-------------|------|-------|
| FP32 | ✅ | ✅ | ✅ | ❓ | | FP32 | ✅ | ✅ | ❓ |
| FP16 | ✅ | ✅ | ❓ | ❓ | | FP16 | ✅ | ❓ | ❓ |
| BF16 | 🚫 | 🚫 | ❓ | ❓ | | BF16 | 🚫 | ❓ | ❓ |
| Q4_0 | ✅ | ✅ | ❓ | ❓ | | Q4_0 | ✅ | ❓ | ❓ |
| Q4_1 | ✅ | ✅ | ❓ | ❓ | | Q4_1 | ✅ | ❓ | ❓ |
| MXFP4 | 🚫 | 🚫 | ❓ | ❓ | | MXFP4 | 🚫 | ❓ | ❓ |
| Q5_0 | ✅ | ✅ | ❓ | ❓ | | Q5_0 | ✅ | ❓ | ❓ |
| Q5_1 | ✅ | ✅ | ❓ | ❓ | | Q5_1 | ✅ | ❓ | ❓ |
| Q8_0 | ✅ | ✅ | ❓ | ❓ | | Q8_0 | ✅ | ❓ | ❓ |
| Q2_K | 🚫 | 🚫 | ❓ | ❓ | | Q2_K | 🚫 | ❓ | ❓ |
| Q3_K | ✅ | ✅ | ❓ | ❓ | | Q3_K | ✅ | ❓ | ❓ |
| Q4_K | ✅ | ✅ | ❓ | ❓ | | Q4_K | ✅ | ❓ | ❓ |
| Q5_K | ✅ | ✅ | ❓ | ❓ | | Q5_K | ✅ | ❓ | ❓ |
| Q6_K | ✅ | ✅ | ❓ | ❓ | | Q6_K | ✅ | ❓ | ❓ |
| TQ1_0 | 🚫 | 🚫 | ❓ | ❓ | | TQ1_0 | 🚫 | ❓ | ❓ |
| TQ2_0 | 🚫 | 🚫 | ❓ | ❓ | | TQ2_0 | 🚫 | ❓ | ❓ |
| IQ2_XXS | 🚫 | 🚫 | ❓ | ❓ | | IQ2_XXS | 🚫 | ❓ | ❓ |
| IQ2_XS | 🚫 | 🚫 | ❓ | ❓ | | IQ2_XS | 🚫 | ❓ | ❓ |
| IQ2_S | 🚫 | 🚫 | ❓ | ❓ | | IQ2_S | 🚫 | ❓ | ❓ |
| IQ3_XXS | 🚫 | 🚫 | ❓ | ❓ | | IQ3_XXS | 🚫 | ❓ | ❓ |
| IQ3_S | 🚫 | 🚫 | ❓ | ❓ | | IQ3_S | 🚫 | ❓ | ❓ |
| IQ1_S | 🚫 | 🚫 | ❓ | ❓ | | IQ1_S | 🚫 | ❓ | ❓ |
| IQ1_M | 🚫 | 🚫 | ❓ | ❓ | | IQ1_M | 🚫 | ❓ | ❓ |
| IQ4_NL | ✅ | ✅ | ❓ | ❓ | | IQ4_NL | ✅ | ❓ | ❓ |
| IQ4_XS | ✅ | ✅ | ❓ | ❓ | | IQ4_XS | ✅ | ❓ | ❓ |
| FP32->FP16 | 🚫 | ✅ | ❓ | ❓ | | FP32->FP16 | 🚫 | ❓ | ❓ |
| FP16->FP32 | 🚫 | ✅ | ❓ | ❓ | | FP16->FP32 | 🚫 | ❓ | ❓ |
- ✅ - acceleration available - ✅ - acceleration available
- 🚫 - acceleration unavailable, will still run using scalar implementation - 🚫 - acceleration unavailable, will still run using scalar implementation
- ❓ - acceleration unknown, please contribute if you can test it yourself - ❓ - acceleration unknown, please contribute if you can test it yourself
Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Aug 22, 2025. Last Updated by **Aaron Teo (aaron.teo1@ibm.com)** on Sep 6, 2025.

View File

@ -59,8 +59,6 @@ cmake --build build --config Release
cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF cmake --preset arm64-windows-llvm-release -D GGML_OPENMP=OFF
cmake --build build-arm64-windows-llvm-release cmake --build build-arm64-windows-llvm-release
``` ```
Building for arm64 can also be done with the MSVC compiler with the build-arm64-windows-MSVC preset, or the standard CMake build instructions. However, note that the MSVC compiler does not support inline ARM assembly code, used e.g. for the accelerated Q4_0_N_M CPU kernels.
For building with ninja generator and clang compiler as default: For building with ninja generator and clang compiler as default:
-set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64 -set path:set LIB=C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\um\x64;C:\Program Files\Microsoft Visual Studio\2022\Community\VC\Tools\MSVC\14.41.34120\lib\x64\uwp;C:\Program Files (x86)\Windows Kits\10\Lib\10.0.22621.0\ucrt\x64
```bash ```bash

View File

@ -333,17 +333,17 @@ static void print_params(struct my_llama_hparams * params) {
} }
static void print_tensor_info(const struct ggml_context * ctx) { static void print_tensor_info(const struct ggml_context * ctx) {
for (auto t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { for (auto * t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) {
LOG_INF("%s: Allocating ", __func__); LOG_INF("%s: Allocating ", __func__);
int64_t total = 1; int64_t total = 1;
int i = 0; int i = 0;
for (; i < ggml_n_dims(t); ++i) { for (; i < ggml_n_dims(t); ++i) {
if (i > 0) LOG("x "); if (i > 0) { LOG_INF("x "); }
LOG("[%" PRId64 "] ", t->ne[i]); LOG_INF("[%" PRId64 "] ", t->ne[i]);
total *= t->ne[i]; total *= t->ne[i];
} }
if (i > 1) LOG("= [%" PRId64 "] ", total); if (i > 1) { LOG_INF("= [%" PRId64 "] ", total); }
LOG("float space for %s\n", ggml_get_name(t)); LOG_INF("float space for %s\n", ggml_get_name(t));
} }
} }

View File

@ -564,7 +564,7 @@ int main(int argc, char ** argv) {
ctx_params.n_ctx = params.n_ctx; ctx_params.n_ctx = params.n_ctx;
ctx_params.n_batch = params.n_batch; ctx_params.n_batch = params.n_batch;
ctx_params.n_ubatch = params.n_ubatch; ctx_params.n_ubatch = params.n_ubatch;
ctx_params.flash_attn = params.flash_attn; ctx_params.flash_attn_type = params.flash_attn_type;
ctx_params.no_perf = params.no_perf; ctx_params.no_perf = params.no_perf;
ctx_params.type_k = params.cache_type_k; ctx_params.type_k = params.cache_type_k;
ctx_params.type_v = params.cache_type_v; ctx_params.type_v = params.cache_type_v;

View File

@ -63,7 +63,7 @@ causal-verify-logits: causal-run-original-model causal-run-converted-model
@MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH} @MODEL_PATH="$(MODEL_PATH)" ./scripts/utils/check-nmse.py -m ${MODEL_PATH}
causal-run-original-embeddings: causal-run-original-embeddings:
@./scripts/causal/run-casual-gen-embeddings-org.sh @./scripts/causal/run-casual-gen-embeddings-org.py
causal-run-converted-embeddings: causal-run-converted-embeddings:
@./scripts/causal/run-converted-model-embeddings-logits.sh @./scripts/causal/run-converted-model-embeddings-logits.sh

View File

@ -1,4 +1,4 @@
#/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -3,11 +3,10 @@
import argparse import argparse
import os import os
import importlib import importlib
import sys
import torch import torch
import numpy as np import numpy as np
from transformers import AutoTokenizer, AutoConfig, AutoModel, AutoModelForCausalLM from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM
from pathlib import Path from pathlib import Path
unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME') unreleased_model_name = os.getenv('UNRELEASED_MODEL_NAME')
@ -43,6 +42,8 @@ if unreleased_model_name:
model = model_class.from_pretrained(model_path) model = model_class.from_pretrained(model_path)
except (ImportError, AttributeError) as e: except (ImportError, AttributeError) as e:
print(f"Failed to import or load model: {e}") print(f"Failed to import or load model: {e}")
print("Falling back to AutoModelForCausalLM")
model = AutoModelForCausalLM.from_pretrained(model_path)
else: else:
model = AutoModelForCausalLM.from_pretrained(model_path) model = AutoModelForCausalLM.from_pretrained(model_path)
print(f"Model class: {type(model)}") print(f"Model class: {type(model)}")

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -7,7 +7,7 @@ base_model:
Recommended way to run this model: Recommended way to run this model:
```sh ```sh
llama-server -hf {namespace}/{model_name}-GGUF llama-server -hf {namespace}/{model_name}-GGUF --embeddings
``` ```
Then the endpoint can be accessed at http://localhost:8080/embedding, for Then the endpoint can be accessed at http://localhost:8080/embedding, for

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,6 @@
#!/usr/bin/env bash
COLLECTION_SLUG=$(python ./create_collection.py --return-slug) COLLECTION_SLUG=$(python ./create_collection.py --return-slug)
echo "Created collection: $COLLECTION_SLUG" echo "Created collection: $COLLECTION_SLUG"

View File

@ -0,0 +1,6 @@
#!/usr/bin/env bash
curl --request POST \
--url http://localhost:8080/embedding \
--header "Content-Type: application/json" \
--data '{"input": "Hello world today"}' \
--silent

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# First try command line argument, then environment variable, then file # First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}" CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"

View File

@ -40,7 +40,7 @@ if os.path.exists(index_path):
file_path = os.path.join(model_path, file_name) file_path = os.path.join(model_path, file_name)
print(f"\n--- From {file_name} ---") print(f"\n--- From {file_name} ---")
with safe_open(file_path, framework="pt") as f: with safe_open(file_path, framework="pt") as f: # type: ignore
for tensor_name in sorted(tensor_names): for tensor_name in sorted(tensor_names):
tensor = f.get_tensor(tensor_name) tensor = f.get_tensor(tensor_name)
print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}") print(f"- {tensor_name} : shape = {tensor.shape}, dtype = {tensor.dtype}")
@ -49,7 +49,7 @@ elif os.path.exists(single_file_path):
# Single file model (original behavior) # Single file model (original behavior)
print("Single-file model detected") print("Single-file model detected")
with safe_open(single_file_path, framework="pt") as f: with safe_open(single_file_path, framework="pt") as f: # type: ignore
keys = f.keys() keys = f.keys()
print("Tensors in model:") print("Tensors in model:")
for key in sorted(keys): for key in sorted(keys):

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e

View File

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
set -e set -e
# #

View File

@ -244,7 +244,7 @@ int main(int argc, char ** argv) {
// stochastic verification // stochastic verification
common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true); common_sampler_sample(smpl, ctx_tgt, drafts[s_keep].i_batch_tgt[i_dft], true);
auto & dist_tgt = *common_sampler_get_candidates(smpl); auto & dist_tgt = *common_sampler_get_candidates(smpl, true);
float p_tgt = 0.0f; float p_tgt = 0.0f;
float p_dft = 0.0f; float p_dft = 0.0f;
@ -493,7 +493,7 @@ int main(int argc, char ** argv) {
common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true); common_sampler_sample(drafts[s].smpl, ctx_dft, drafts[s].i_batch_dft, true);
const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl); const auto * cur_p = common_sampler_get_candidates(drafts[s].smpl, true);
for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) { for (int k = 0; k < std::min(n_seq_dft + 3, (int) cur_p->size); ++k) {
LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n", LOG_DBG(" - draft candidate %3d for seq %3d, pos %3d: %6d (%8.3f) '%s'\n",

View File

@ -1,5 +1,5 @@
cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories. cmake_minimum_required(VERSION 3.14) # for add_link_options and implicit target directories.
project("ggml" C CXX) project("ggml" C CXX ASM)
include(CheckIncludeFileCXX) include(CheckIncludeFileCXX)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON) set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
@ -129,10 +129,11 @@ endif()
option(GGML_LASX "ggml: enable lasx" ON) option(GGML_LASX "ggml: enable lasx" ON)
option(GGML_LSX "ggml: enable lsx" ON) option(GGML_LSX "ggml: enable lsx" ON)
option(GGML_RVV "ggml: enable rvv" ON) option(GGML_RVV "ggml: enable rvv" ON)
option(GGML_RV_ZFH "ggml: enable riscv zfh" OFF) option(GGML_RV_ZFH "ggml: enable riscv zfh" ON)
option(GGML_RV_ZVFH "ggml: enable riscv zvfh" ON)
option(GGML_RV_ZICBOP "ggml: enable riscv zicbop" ON)
option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF) option(GGML_XTHEADVECTOR "ggml: enable xtheadvector" OFF)
option(GGML_VXE "ggml: enable vxe" ON) option(GGML_VXE "ggml: enable vxe" ON)
option(GGML_NNPA "ggml: enable nnpa" OFF) # temp disabled by default, see: https://github.com/ggml-org/llama.cpp/issues/14877
option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF) option(GGML_CPU_ALL_VARIANTS "ggml: build all variants of the CPU backend (requires GGML_BACKEND_DL)" OFF)
set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM") set(GGML_CPU_ARM_ARCH "" CACHE STRING "ggml: CPU architecture for ARM")

View File

@ -307,6 +307,9 @@ extern "C" {
GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend);
GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node);
// Split graph without allocating it
GGML_API void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
// Allocate and compute graph on the backend scheduler // Allocate and compute graph on the backend scheduler
GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success GGML_API bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph); // returns success
GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph); GGML_API enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, struct ggml_cgraph * graph);
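The header now exposes `ggml_backend_sched_split_graph` separately from allocation. A hedged sketch of how the split step can be inspected on its own; the composition is illustrative, but every call used below appears in this header or elsewhere in this diff:

```cpp
#include "ggml.h"
#include "ggml-backend.h"
#include <cstdio>

// Sketch: let the scheduler assign backends and split the graph without
// allocating it, inspect the result, then allocate and compute as usual.
static void inspect_and_run(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
    ggml_backend_sched_reset(sched);              // clear previous assignments
    ggml_backend_sched_split_graph(sched, graph); // split only, no allocation yet

    printf("graph splits: %d across %d backend(s)\n",
           ggml_backend_sched_get_n_splits(sched),
           ggml_backend_sched_get_n_backends(sched));

    if (ggml_backend_sched_alloc_graph(sched, graph)) { // returns success
        ggml_backend_sched_graph_compute(sched, graph);
    }
}
```

Splitting without allocating is mainly useful for measurement-style passes, which is consistent with `ggml_backend_sched_reserve` now calling the split step internally later in this diff.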

View File

@ -101,7 +101,6 @@ extern "C" {
GGML_BACKEND_API int ggml_cpu_has_riscv_v (void); GGML_BACKEND_API int ggml_cpu_has_riscv_v (void);
GGML_BACKEND_API int ggml_cpu_has_vsx (void); GGML_BACKEND_API int ggml_cpu_has_vsx (void);
GGML_BACKEND_API int ggml_cpu_has_vxe (void); GGML_BACKEND_API int ggml_cpu_has_vxe (void);
GGML_BACKEND_API int ggml_cpu_has_nnpa (void);
GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void); GGML_BACKEND_API int ggml_cpu_has_wasm_simd (void);
GGML_BACKEND_API int ggml_cpu_has_llamafile (void); GGML_BACKEND_API int ggml_cpu_has_llamafile (void);

View File

@ -511,6 +511,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D, GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL, GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK, GGML_OP_IM2COL_BACK,
GGML_OP_IM2COL_3D,
GGML_OP_CONV_2D, GGML_OP_CONV_2D,
GGML_OP_CONV_3D, GGML_OP_CONV_3D,
GGML_OP_CONV_2D_DW, GGML_OP_CONV_2D_DW,
@ -1870,6 +1871,41 @@ extern "C" {
int d0, // dilation dimension 0 int d0, // dilation dimension 0
int d1); // dilation dimension 1 int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_im2col_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2, // dilation depth
enum ggml_type dst_type);
// a: [OC*IC, KD, KH, KW]
// b: [N*IC, ID, IH, IW]
// result: [N*OC, OD, OH, OW]
GGML_API struct ggml_tensor * ggml_conv_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2 // dilation depth
);
// kernel size is a->ne[0] x a->ne[1] // kernel size is a->ne[0] x a->ne[1]
// stride is equal to kernel size // stride is equal to kernel size
// padding is zero // padding is zero
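Together with the new `GGML_OP_IM2COL_3D` op above, `ggml_conv_3d` provides an im2col-based 3-D convolution. A minimal usage sketch under stated assumptions: sizes and tensor types are illustrative, the layout follows the comment (KW as the innermost dimension), and execution is left to whichever backend or scheduler the caller already uses:

```cpp
#include "ggml.h"

int main() {
    const int64_t IC = 3, OC = 8, N = 1;          // channels in/out, batch
    const int64_t KW = 3, KH = 3, KD = 3;          // kernel size
    const int64_t IW = 16, IH = 16, ID = 16;       // input size

    struct ggml_init_params params = {
        /*.mem_size   =*/ 256u*1024*1024,
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ false,
    };
    struct ggml_context * ctx = ggml_init(params);

    // kernel [KW, KH, KD, IC*OC], input [IW, IH, ID, IC*N] (ne[0] innermost)
    struct ggml_tensor * a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, KW, KH, KD, IC*OC);
    struct ggml_tensor * b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, IW, IH, ID, IC*N);

    // stride 1, padding 1, dilation 1 in all three spatial dimensions
    struct ggml_tensor * out = ggml_conv_3d(ctx, a, b, IC,
                                            /*s*/ 1, 1, 1, /*p*/ 1, 1, 1, /*d*/ 1, 1, 1);

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);
    // ... compute gf with a backend/scheduler of choice ...

    ggml_free(ctx);
    return 0;
}
```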
@ -1941,7 +1977,7 @@ extern "C" {
int d0, // dilation dimension 0 int d0, // dilation dimension 0
int d1); // dilation dimension 1 int d1); // dilation dimension 1
GGML_API struct ggml_tensor * ggml_conv_3d( GGML_API struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC] struct ggml_tensor * a, // kernel [KW, KH, KD, IC * OC]
struct ggml_tensor * b, // input [W, H, D, C * N] struct ggml_tensor * b, // input [W, H, D, C * N]
@ -2048,6 +2084,19 @@ extern "C" {
int p2, int p2,
int p3); int p3);
GGML_API struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
);
// pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c] // pad each dimension with reflection: [a, b, c, d] -> [b, a, b, c, d, c]
GGML_API struct ggml_tensor * ggml_pad_reflect_1d( GGML_API struct ggml_tensor * ggml_pad_reflect_1d(
struct ggml_context * ctx, struct ggml_context * ctx,
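`ggml_pad_ext` generalizes `ggml_pad` with independent leading (`lp*`) and trailing (`rp*`) padding per dimension. A small fragment, assuming an already-initialized `ggml_context * ctx` and that each output dimension grows by `lp + rp`, as the parameter names suggest:

```cpp
// Pad a [32, 32, 4, 1] tensor by 2 elements on the left and 1 on the right of
// dim 0, and 1 element on each side of dim 1, leaving the last two dims alone.
struct ggml_tensor * x      = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 32, 32, 4, 1);
struct ggml_tensor * padded = ggml_pad_ext(ctx, x,
                                           /*lp0, rp0*/ 2, 1,
                                           /*lp1, rp1*/ 1, 1,
                                           /*lp2, rp2*/ 0, 0,
                                           /*lp3, rp3*/ 0, 0);
// expected: padded->ne = {32 + 2 + 1, 32 + 1 + 1, 4, 1} = {35, 34, 4, 1}
```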

View File

@ -31,6 +31,7 @@
// backend buffer type // backend buffer type
const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) { const char * ggml_backend_buft_name(ggml_backend_buffer_type_t buft) {
GGML_ASSERT(buft);
return buft->iface.get_name(buft); return buft->iface.get_name(buft);
} }
@ -40,14 +41,17 @@ ggml_backend_buffer_t ggml_backend_buft_alloc_buffer(ggml_backend_buffer_type_t
return ggml_backend_buffer_init(buft, {}, NULL, 0); return ggml_backend_buffer_init(buft, {}, NULL, 0);
} }
GGML_ASSERT(buft);
return buft->iface.alloc_buffer(buft, size); return buft->iface.alloc_buffer(buft, size);
} }
size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) { size_t ggml_backend_buft_get_alignment(ggml_backend_buffer_type_t buft) {
GGML_ASSERT(buft);
return buft->iface.get_alignment(buft); return buft->iface.get_alignment(buft);
} }
size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) { size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
GGML_ASSERT(buft);
// get_max_size is optional, defaults to SIZE_MAX // get_max_size is optional, defaults to SIZE_MAX
if (buft->iface.get_max_size) { if (buft->iface.get_max_size) {
return buft->iface.get_max_size(buft); return buft->iface.get_max_size(buft);
@ -56,6 +60,7 @@ size_t ggml_backend_buft_get_max_size(ggml_backend_buffer_type_t buft) {
} }
size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) { size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const struct ggml_tensor * tensor) {
GGML_ASSERT(buft);
// get_alloc_size is optional, defaults to ggml_nbytes // get_alloc_size is optional, defaults to ggml_nbytes
if (buft->iface.get_alloc_size) { if (buft->iface.get_alloc_size) {
size_t size = buft->iface.get_alloc_size(buft, tensor); size_t size = buft->iface.get_alloc_size(buft, tensor);
@ -66,6 +71,7 @@ size_t ggml_backend_buft_get_alloc_size(ggml_backend_buffer_type_t buft, const s
} }
bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) { bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
GGML_ASSERT(buft);
if (buft->iface.is_host) { if (buft->iface.is_host) {
return buft->iface.is_host(buft); return buft->iface.is_host(buft);
} }
@ -73,6 +79,7 @@ bool ggml_backend_buft_is_host(ggml_backend_buffer_type_t buft) {
} }
ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) { ggml_backend_dev_t ggml_backend_buft_get_device(ggml_backend_buffer_type_t buft) {
GGML_ASSERT(buft);
return buft->device; return buft->device;
} }
@ -110,10 +117,12 @@ void ggml_backend_buffer_free(ggml_backend_buffer_t buffer) {
} }
size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) { size_t ggml_backend_buffer_get_size(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
return buffer->size; return buffer->size;
} }
void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) { void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
// get_base is optional if the buffer is zero-sized // get_base is optional if the buffer is zero-sized
if (buffer->size == 0) { if (buffer->size == 0) {
return NULL; return NULL;
@ -127,6 +136,7 @@ void * ggml_backend_buffer_get_base(ggml_backend_buffer_t buffer) {
} }
enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) { enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor) {
GGML_ASSERT(buffer);
// init_tensor is optional // init_tensor is optional
if (buffer->iface.init_tensor) { if (buffer->iface.init_tensor) {
return buffer->iface.init_tensor(buffer, tensor); return buffer->iface.init_tensor(buffer, tensor);
@ -135,6 +145,7 @@ enum ggml_status ggml_backend_buffer_init_tensor(ggml_backend_buffer_t buffer, s
} }
void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { void ggml_backend_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
GGML_ASSERT(buffer);
// clear is optional if the buffer is zero-sized // clear is optional if the buffer is zero-sized
if (buffer->size == 0) { if (buffer->size == 0) {
return; return;
@ -160,6 +171,7 @@ bool ggml_backend_buffer_is_host(ggml_backend_buffer_t buffer) {
} }
void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
GGML_ASSERT(buffer);
buffer->usage = usage; buffer->usage = usage;
// FIXME: add a generic callback to the buffer interface // FIXME: add a generic callback to the buffer interface
@ -169,14 +181,17 @@ void ggml_backend_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backe
} }
enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) { enum ggml_backend_buffer_usage ggml_backend_buffer_get_usage(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
return buffer->usage; return buffer->usage;
} }
ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) { ggml_backend_buffer_type_t ggml_backend_buffer_get_type(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
return buffer->buft; return buffer->buft;
} }
void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) { void ggml_backend_buffer_reset(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
if (buffer->iface.reset) { if (buffer->iface.reset) {
buffer->iface.reset(buffer); buffer->iface.reset(buffer);
} }
@ -215,6 +230,7 @@ void ggml_backend_free(ggml_backend_t backend) {
} }
ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) { ggml_backend_buffer_type_t ggml_backend_get_default_buffer_type(ggml_backend_t backend) {
GGML_ASSERT(backend);
return ggml_backend_dev_buffer_type(backend->device); return ggml_backend_dev_buffer_type(backend->device);
} }
@ -231,6 +247,8 @@ size_t ggml_backend_get_max_size(ggml_backend_t backend) {
} }
void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(backend);
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor write out of bounds");
@ -242,6 +260,8 @@ void ggml_backend_tensor_set_async(ggml_backend_t backend, struct ggml_tensor *
} }
void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { void ggml_backend_tensor_get_async(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(backend);
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->data != NULL && "tensor not allocated"); GGML_ASSERT(tensor->data != NULL && "tensor not allocated");
GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds"); GGML_ASSERT(offset + size <= ggml_nbytes(tensor) && "tensor read out of bounds");
@ -283,6 +303,7 @@ void ggml_backend_tensor_get(const struct ggml_tensor * tensor, void * data, siz
} }
void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
GGML_ASSERT(tensor);
ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer; ggml_backend_buffer_t buf = tensor->view_src ? tensor->view_src->buffer : tensor->buffer;
if (size == 0) { if (size == 0) {
@ -298,6 +319,7 @@ void ggml_backend_tensor_memset(struct ggml_tensor * tensor, uint8_t value, size
} }
void ggml_backend_synchronize(ggml_backend_t backend) { void ggml_backend_synchronize(ggml_backend_t backend) {
GGML_ASSERT(backend);
if (backend->iface.synchronize == NULL) { if (backend->iface.synchronize == NULL) {
return; return;
} }
@ -306,18 +328,21 @@ void ggml_backend_synchronize(ggml_backend_t backend) {
} }
ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) { ggml_backend_graph_plan_t ggml_backend_graph_plan_create(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_create != NULL); GGML_ASSERT(backend->iface.graph_plan_create != NULL);
return backend->iface.graph_plan_create(backend, cgraph); return backend->iface.graph_plan_create(backend, cgraph);
} }
void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { void ggml_backend_graph_plan_free(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_free != NULL); GGML_ASSERT(backend->iface.graph_plan_free != NULL);
backend->iface.graph_plan_free(backend, plan); backend->iface.graph_plan_free(backend, plan);
} }
enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) { enum ggml_status ggml_backend_graph_plan_compute(ggml_backend_t backend, ggml_backend_graph_plan_t plan) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.graph_plan_compute != NULL); GGML_ASSERT(backend->iface.graph_plan_compute != NULL);
return backend->iface.graph_plan_compute(backend, plan); return backend->iface.graph_plan_compute(backend, plan);
@ -330,22 +355,27 @@ enum ggml_status ggml_backend_graph_compute(ggml_backend_t backend, struct ggml_
} }
enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) { enum ggml_status ggml_backend_graph_compute_async(ggml_backend_t backend, struct ggml_cgraph * cgraph) {
GGML_ASSERT(backend);
return backend->iface.graph_compute(backend, cgraph); return backend->iface.graph_compute(backend, cgraph);
} }
bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) { bool ggml_backend_supports_op(ggml_backend_t backend, const struct ggml_tensor * op) {
GGML_ASSERT(backend);
return ggml_backend_dev_supports_op(backend->device, op); return ggml_backend_dev_supports_op(backend->device, op);
} }
bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) { bool ggml_backend_supports_buft(ggml_backend_t backend, ggml_backend_buffer_type_t buft) {
GGML_ASSERT(backend);
return ggml_backend_dev_supports_buft(backend->device, buft); return ggml_backend_dev_supports_buft(backend->device, buft);
} }
bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) { bool ggml_backend_offload_op(ggml_backend_t backend, const struct ggml_tensor * op) {
GGML_ASSERT(backend);
return ggml_backend_dev_offload_op(backend->device, op); return ggml_backend_dev_offload_op(backend->device, op);
} }
ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) { ggml_backend_dev_t ggml_backend_get_device(ggml_backend_t backend) {
GGML_ASSERT(backend);
return backend->device; return backend->device;
} }
@ -381,6 +411,7 @@ void ggml_backend_tensor_copy_async(ggml_backend_t backend_src, ggml_backend_t b
return; return;
} }
GGML_ASSERT(backend_dst);
if (backend_dst->iface.cpy_tensor_async != NULL) { if (backend_dst->iface.cpy_tensor_async != NULL) {
if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) { if (backend_dst->iface.cpy_tensor_async(backend_src, backend_dst, src, dst)) {
return; return;
@ -412,18 +443,21 @@ void ggml_backend_event_free(ggml_backend_event_t event) {
} }
void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) { void ggml_backend_event_record(ggml_backend_event_t event, ggml_backend_t backend) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_record != NULL); GGML_ASSERT(backend->iface.event_record != NULL);
backend->iface.event_record(backend, event); backend->iface.event_record(backend, event);
} }
void ggml_backend_event_synchronize(ggml_backend_event_t event) { void ggml_backend_event_synchronize(ggml_backend_event_t event) {
GGML_ASSERT(event);
GGML_ASSERT(event->device->iface.event_synchronize); GGML_ASSERT(event->device->iface.event_synchronize);
event->device->iface.event_synchronize(event->device, event); event->device->iface.event_synchronize(event->device, event);
} }
void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) { void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event) {
GGML_ASSERT(backend);
GGML_ASSERT(backend->iface.event_wait != NULL); GGML_ASSERT(backend->iface.event_wait != NULL);
backend->iface.event_wait(backend, event); backend->iface.event_wait(backend, event);
@ -432,18 +466,22 @@ void ggml_backend_event_wait(ggml_backend_t backend, ggml_backend_event_t event)
// Backend device // Backend device
const char * ggml_backend_dev_name(ggml_backend_dev_t device) { const char * ggml_backend_dev_name(ggml_backend_dev_t device) {
GGML_ASSERT(device);
return device->iface.get_name(device); return device->iface.get_name(device);
} }
const char * ggml_backend_dev_description(ggml_backend_dev_t device) { const char * ggml_backend_dev_description(ggml_backend_dev_t device) {
GGML_ASSERT(device);
return device->iface.get_description(device); return device->iface.get_description(device);
} }
void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { void ggml_backend_dev_memory(ggml_backend_dev_t device, size_t * free, size_t * total) {
GGML_ASSERT(device);
device->iface.get_memory(device, free, total); device->iface.get_memory(device, free, total);
} }
enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) { enum ggml_backend_dev_type ggml_backend_dev_type(ggml_backend_dev_t device) {
GGML_ASSERT(device);
return device->iface.get_type(device); return device->iface.get_type(device);
} }
@ -453,18 +491,22 @@ void ggml_backend_dev_get_props(ggml_backend_dev_t device, struct ggml_backend_d
} }
ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) { ggml_backend_reg_t ggml_backend_dev_backend_reg(ggml_backend_dev_t device) {
GGML_ASSERT(device);
return device->reg; return device->reg;
} }
ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) { ggml_backend_t ggml_backend_dev_init(ggml_backend_dev_t device, const char * params) {
GGML_ASSERT(device);
return device->iface.init_backend(device, params); return device->iface.init_backend(device, params);
} }
ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) { ggml_backend_buffer_type_t ggml_backend_dev_buffer_type(ggml_backend_dev_t device) {
GGML_ASSERT(device);
return device->iface.get_buffer_type(device); return device->iface.get_buffer_type(device);
} }
ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) { ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device) {
GGML_ASSERT(device);
if (device->iface.get_host_buffer_type == NULL) { if (device->iface.get_host_buffer_type == NULL) {
return NULL; return NULL;
} }
@ -473,18 +515,22 @@ ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t
} }
ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) { ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size) {
GGML_ASSERT(device);
return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size); return device->iface.buffer_from_host_ptr(device, ptr, size, max_tensor_size);
} }
bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { bool ggml_backend_dev_supports_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
GGML_ASSERT(device);
return device->iface.supports_op(device, op); return device->iface.supports_op(device, op);
} }
bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) { bool ggml_backend_dev_supports_buft(ggml_backend_dev_t device, ggml_backend_buffer_type_t buft) {
GGML_ASSERT(device);
return device->iface.supports_buft(device, buft); return device->iface.supports_buft(device, buft);
} }
bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) { bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_tensor * op) {
GGML_ASSERT(device);
if (device->iface.offload_op != NULL) { if (device->iface.offload_op != NULL) {
return device->iface.offload_op(device, op); return device->iface.offload_op(device, op);
} }
@ -495,18 +541,22 @@ bool ggml_backend_dev_offload_op(ggml_backend_dev_t device, const struct ggml_te
// Backend (reg) // Backend (reg)
const char * ggml_backend_reg_name(ggml_backend_reg_t reg) { const char * ggml_backend_reg_name(ggml_backend_reg_t reg) {
GGML_ASSERT(reg);
return reg->iface.get_name(reg); return reg->iface.get_name(reg);
} }
size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) { size_t ggml_backend_reg_dev_count(ggml_backend_reg_t reg) {
GGML_ASSERT(reg);
return reg->iface.get_device_count(reg); return reg->iface.get_device_count(reg);
} }
ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) { ggml_backend_dev_t ggml_backend_reg_dev_get(ggml_backend_reg_t reg, size_t index) {
GGML_ASSERT(reg);
return reg->iface.get_device(reg, index); return reg->iface.get_device(reg, index);
} }
void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) { void * ggml_backend_reg_get_proc_address(ggml_backend_reg_t reg, const char * name) {
GGML_ASSERT(reg);
if (!reg->iface.get_proc_address) { if (!reg->iface.get_proc_address) {
return NULL; return NULL;
} }
@ -521,6 +571,7 @@ struct ggml_backend_multi_buffer_context {
}; };
static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) { for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_free(ctx->buffers[i]); ggml_backend_buffer_free(ctx->buffers[i]);
@ -531,6 +582,7 @@ static void ggml_backend_multi_buffer_free_buffer(ggml_backend_buffer_t buffer)
} }
static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { static void ggml_backend_multi_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
GGML_ASSERT(buffer);
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) { for (size_t i = 0; i < ctx->n_buffers; i++) {
ggml_backend_buffer_clear(ctx->buffers[i], value); ggml_backend_buffer_clear(ctx->buffers[i], value);
@ -566,10 +618,12 @@ ggml_backend_buffer_t ggml_backend_multi_buffer_alloc_buffer(ggml_backend_buffer
} }
bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) { bool ggml_backend_buffer_is_multi_buffer(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer; return buffer->iface.free_buffer == ggml_backend_multi_buffer_free_buffer;
} }
void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) { void ggml_backend_multi_buffer_set_usage(ggml_backend_buffer_t buffer, enum ggml_backend_buffer_usage usage) {
GGML_ASSERT(buffer);
GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer)); GGML_ASSERT(ggml_backend_buffer_is_multi_buffer(buffer));
ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context; ggml_backend_multi_buffer_context * ctx = (ggml_backend_multi_buffer_context *) buffer->context;
for (size_t i = 0; i < ctx->n_buffers; i++) { for (size_t i = 0; i < ctx->n_buffers; i++) {
@ -597,7 +651,7 @@ static bool ggml_is_view_op(enum ggml_op op) {
#endif #endif
#ifndef GGML_SCHED_MAX_SPLIT_INPUTS #ifndef GGML_SCHED_MAX_SPLIT_INPUTS
#define GGML_SCHED_MAX_SPLIT_INPUTS GGML_MAX_SRC #define GGML_SCHED_MAX_SPLIT_INPUTS 30
#endif #endif
#ifndef GGML_SCHED_MAX_COPIES #ifndef GGML_SCHED_MAX_COPIES
@ -848,7 +902,7 @@ static void ggml_backend_sched_set_if_supported(ggml_backend_sched_t sched, stru
} }
// assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend // assigns backends to ops and splits the graph into subgraphs that can be computed on the same backend
static void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { void ggml_backend_sched_split_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
// reset splits // reset splits
sched->n_splits = 0; sched->n_splits = 0;
sched->n_graph_inputs = 0; sched->n_graph_inputs = 0;
@ -1349,6 +1403,7 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
} }
static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) { static enum ggml_status ggml_backend_sched_compute_splits(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
struct ggml_backend_sched_split * splits = sched->splits; struct ggml_backend_sched_split * splits = sched->splits;
ggml_tensor * prev_ids_tensor = nullptr; ggml_tensor * prev_ids_tensor = nullptr;
@ -1617,6 +1672,7 @@ void ggml_backend_sched_free(ggml_backend_sched_t sched) {
} }
void ggml_backend_sched_reset(ggml_backend_sched_t sched) { void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
// reset state for the next run // reset state for the next run
if (!sched->is_reset) { if (!sched->is_reset) {
ggml_hash_set_reset(&sched->hash_set); ggml_hash_set_reset(&sched->hash_set);
@ -1628,8 +1684,11 @@ void ggml_backend_sched_reset(ggml_backend_sched_t sched) {
} }
bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) { bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph * measure_graph) {
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs); GGML_ASSERT((int)sched->hash_set.size >= measure_graph->n_nodes + measure_graph->n_leafs);
ggml_backend_sched_reset(sched);
ggml_backend_sched_synchronize(sched); ggml_backend_sched_synchronize(sched);
ggml_backend_sched_split_graph(sched, measure_graph); ggml_backend_sched_split_graph(sched, measure_graph);
@ -1644,6 +1703,7 @@ bool ggml_backend_sched_reserve(ggml_backend_sched_t sched, struct ggml_cgraph *
} }
bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { bool ggml_backend_sched_alloc_graph(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT(sched);
GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs); GGML_ASSERT((int)sched->hash_set.size >= graph->n_nodes + graph->n_leafs);
GGML_ASSERT(!sched->is_alloc); GGML_ASSERT(!sched->is_alloc);
@ -1668,6 +1728,7 @@ enum ggml_status ggml_backend_sched_graph_compute(ggml_backend_sched_t sched, st
} }
enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) { enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sched, struct ggml_cgraph * graph) {
GGML_ASSERT(sched);
if (!sched->is_reset && !sched->is_alloc) { if (!sched->is_reset && !sched->is_alloc) {
ggml_backend_sched_reset(sched); ggml_backend_sched_reset(sched);
} }
@ -1682,6 +1743,7 @@ enum ggml_status ggml_backend_sched_graph_compute_async(ggml_backend_sched_t sch
} }
void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) { void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
for (int i = 0; i < sched->n_backends; i++) { for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]); ggml_backend_synchronize(sched->backends[i]);
} }
@ -1694,28 +1756,34 @@ void ggml_backend_sched_synchronize(ggml_backend_sched_t sched) {
} }
void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) { void ggml_backend_sched_set_eval_callback(ggml_backend_sched_t sched, ggml_backend_sched_eval_callback callback, void * user_data) {
GGML_ASSERT(sched);
sched->callback_eval = callback; sched->callback_eval = callback;
sched->callback_eval_user_data = user_data; sched->callback_eval_user_data = user_data;
} }
int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) { int ggml_backend_sched_get_n_splits(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
return sched->n_splits; return sched->n_splits;
} }
int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) { int ggml_backend_sched_get_n_copies(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
return sched->n_copies; return sched->n_copies;
} }
int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) { int ggml_backend_sched_get_n_backends(ggml_backend_sched_t sched) {
GGML_ASSERT(sched);
return sched->n_backends; return sched->n_backends;
} }
ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) { ggml_backend_t ggml_backend_sched_get_backend(ggml_backend_sched_t sched, int i) {
GGML_ASSERT(sched);
GGML_ASSERT(i >= 0 && i < sched->n_backends); GGML_ASSERT(i >= 0 && i < sched->n_backends);
return sched->backends[i]; return sched->backends[i];
} }
size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) { size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend) {
GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend); int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
@ -1723,6 +1791,7 @@ size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backe
} }
void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) {
GGML_ASSERT(sched);
int backend_index = ggml_backend_sched_backend_id(sched, backend); int backend_index = ggml_backend_sched_backend_id(sched, backend);
GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends); GGML_ASSERT(backend_index >= 0 && backend_index < sched->n_backends);
tensor_backend_id(node) = backend_index; tensor_backend_id(node) = backend_index;
@ -1731,6 +1800,7 @@ void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct gg
} }
ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) { ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node) {
GGML_ASSERT(sched);
int backend_index = tensor_backend_id(node); int backend_index = tensor_backend_id(node);
if (backend_index == -1) { if (backend_index == -1) {
return NULL; return NULL;
@ -1741,6 +1811,7 @@ ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched,
// utils // utils
enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) { enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->view_src != NULL); GGML_ASSERT(tensor->view_src != NULL);
GGML_ASSERT(tensor->view_src->buffer != NULL); GGML_ASSERT(tensor->view_src->buffer != NULL);
@ -1752,6 +1823,7 @@ enum ggml_status ggml_backend_view_init(struct ggml_tensor * tensor) {
} }
enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) { enum ggml_status ggml_backend_tensor_alloc(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, void * addr) {
GGML_ASSERT(tensor);
GGML_ASSERT(tensor->buffer == NULL); GGML_ASSERT(tensor->buffer == NULL);
GGML_ASSERT(tensor->data == NULL); GGML_ASSERT(tensor->data == NULL);
GGML_ASSERT(tensor->view_src == NULL); GGML_ASSERT(tensor->view_src == NULL);
@ -1825,6 +1897,7 @@ static void graph_copy_init_tensor(struct ggml_hash_set * hash_set, struct ggml_
} }
struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) { struct ggml_backend_graph_copy ggml_backend_graph_copy(ggml_backend_t backend, struct ggml_cgraph * graph) {
GGML_ASSERT(graph);
struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size); struct ggml_hash_set hash_set = ggml_hash_set_new(graph->visited_hash_set.size);
struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT struct ggml_tensor ** node_copies = (ggml_tensor **) calloc(hash_set.size, sizeof(node_copies[0])); // NOLINT
bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0])); bool * node_init = (bool *) calloc(hash_set.size, sizeof(node_init[0]));
@ -1969,6 +2042,7 @@ bool ggml_backend_compare_graph_backend(ggml_backend_t backend1, ggml_backend_t
// CPU backend - buffer // CPU backend - buffer
static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) { static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
uintptr_t data = (uintptr_t)buffer->context; uintptr_t data = (uintptr_t)buffer->context;
// align the buffer // align the buffer
@ -1980,28 +2054,33 @@ static void * ggml_backend_cpu_buffer_get_base(ggml_backend_buffer_t buffer) {
} }
static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) { static void ggml_backend_cpu_buffer_free_buffer(ggml_backend_buffer_t buffer) {
GGML_ASSERT(buffer);
ggml_aligned_free(buffer->context, buffer->size); ggml_aligned_free(buffer->context, buffer->size);
} }
static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { static void ggml_backend_cpu_buffer_memset_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) {
GGML_ASSERT(tensor);
memset((char *)tensor->data + offset, value, size); memset((char *)tensor->data + offset, value, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) { static void ggml_backend_cpu_buffer_set_tensor(ggml_backend_buffer_t buffer, struct ggml_tensor * tensor, const void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
memcpy((char *)tensor->data + offset, data, size); memcpy((char *)tensor->data + offset, data, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { static void ggml_backend_cpu_buffer_get_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) {
GGML_ASSERT(tensor);
memcpy(data, (const char *)tensor->data + offset, size); memcpy(data, (const char *)tensor->data + offset, size);
GGML_UNUSED(buffer); GGML_UNUSED(buffer);
} }
static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) { static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const struct ggml_tensor * src, struct ggml_tensor * dst) {
GGML_ASSERT(src);
if (ggml_backend_buffer_is_host(src->buffer)) { if (ggml_backend_buffer_is_host(src->buffer)) {
memcpy(dst->data, src->data, ggml_nbytes(src)); memcpy(dst->data, src->data, ggml_nbytes(src));
return true; return true;
@ -2012,6 +2091,7 @@ static bool ggml_backend_cpu_buffer_cpy_tensor(ggml_backend_buffer_t buffer, con
} }
static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { static void ggml_backend_cpu_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) {
GGML_ASSERT(buffer);
memset(buffer->context, value, buffer->size); memset(buffer->context, value, buffer->size);
} }
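The `GGML_ASSERT` additions throughout this file harden the public buffer and buffer-type accessors against NULL handles rather than changing behaviour. For orientation, a minimal sketch of the call sequence they guard, using the CPU buffer type; `ggml_backend_cpu_buffer_type()` is assumed to be available from the public headers, and error handling is reduced to a single NULL check:

```cpp
#include "ggml-backend.h"
#include <stdio.h>

static void buffer_demo(void) {
    ggml_backend_buffer_type_t buft = ggml_backend_cpu_buffer_type();
    printf("buft: %s, alignment: %zu, host: %d\n",
           ggml_backend_buft_name(buft),
           ggml_backend_buft_get_alignment(buft),
           ggml_backend_buft_is_host(buft));

    ggml_backend_buffer_t buf = ggml_backend_buft_alloc_buffer(buft, 1024*1024);
    if (buf == NULL) {
        return;
    }
    printf("buffer size: %zu, base: %p\n",
           ggml_backend_buffer_get_size(buf),
           ggml_backend_buffer_get_base(buf));

    ggml_backend_buffer_clear(buf, 0); // zero the allocation
    ggml_backend_buffer_free(buf);
}
```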

View File

@ -70,6 +70,8 @@
#include <aclnnop/aclnn_zero.h> #include <aclnnop/aclnn_zero.h>
#include <aclnnop/aclnn_index_copy.h> #include <aclnnop/aclnn_index_copy.h>
#include <aclnnop/aclnn_index_select.h> #include <aclnnop/aclnn_index_select.h>
#include <aclnnop/aclnn_clamp.h>
#include <aclnnop/aclnn_threshold.h>
#include <float.h> #include <float.h>
#include <cmath> #include <cmath>
@ -587,9 +589,16 @@ void ggml_cann_pad(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// the position of elements in the array means which dirction to padding, // the position of elements in the array means which dirction to padding,
// each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind, // each position means: [dim0.front, dim0.behind, dim1.front, dim1.behind,
// dim2.front, dim2.behind, dim3.front, dim3.behind] // dim2.front, dim2.behind, dim3.front, dim3.behind]
int64_t paddings[] = { const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
0, dst->ne[0] - src->ne[0], 0, dst->ne[1] - src->ne[1], const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
0, dst->ne[2] - src->ne[2], 0, dst->ne[3] - src->ne[3]}; const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
int64_t paddings[] = {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3};
aclnn_pad(ctx, acl_src, acl_dst, paddings); aclnn_pad(ctx, acl_src, acl_dst, paddings);
ggml_cann_release_resources(ctx, acl_src, acl_dst); ggml_cann_release_resources(ctx, acl_src, acl_dst);
} }
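This replaces the old shape-difference padding with the eight values that `ggml_pad_ext` stores as op params. A hedged fragment spelling out the mapping; it assumes backend-side code where `dst` was produced by `ggml_pad_ext` and `ggml_get_op_params_i32` is the same internal helper used above:

```cpp
// params 0..7 hold {lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3}, so for every dim i
// the padded size should equal: source size + leading pad + trailing pad.
for (int i = 0; i < GGML_MAX_DIMS; ++i) {
    const int32_t lp = ggml_get_op_params_i32(dst, 2*i);
    const int32_t rp = ggml_get_op_params_i32(dst, 2*i + 1);
    GGML_ASSERT(dst->ne[i] == dst->src[0]->ne[i] + lp + rp);
}
```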
@ -964,8 +973,8 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
} }
aclTensor* acl_gamma = get_f32_cache_acl_tensor( aclTensor* acl_gamma = get_f32_cache_acl_tensor(
ctx, ctx,
&ctx.f32_one_cache, &ctx.rms_norm_one_tensor_cache.cache,
ctx.f32_one_cache_element, ctx.rms_norm_one_tensor_cache.size,
src->ne, src->ne,
acl_gamma_nb, acl_gamma_nb,
1, // dims 1, // dims
@ -973,18 +982,19 @@ void ggml_cann_rms_norm(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
); );
// build rstd, zero... // build rstd, zero...
size_t acl_rstd_nb[GGML_MAX_DIMS]; int64_t acl_rstd_ne[] = {src->ne[1], src->ne[2], src->ne[3]};
size_t acl_rstd_nb[GGML_MAX_DIMS - 1];
acl_rstd_nb[0] = sizeof(float); acl_rstd_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS - 1; i++) {
acl_rstd_nb[i] = acl_rstd_nb[i - 1] * src->ne[i - 1]; acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];
} }
aclTensor* acl_rstd = get_f32_cache_acl_tensor( aclTensor* acl_rstd = get_f32_cache_acl_tensor(
ctx, ctx,
&ctx.f32_zero_cache, &ctx.rms_norm_zero_tensor_cache.cache,
ctx.f32_zero_cache_element, ctx.rms_norm_zero_tensor_cache.size,
src->ne, acl_rstd_ne,
acl_rstd_nb, acl_rstd_nb,
GGML_MAX_DIMS, GGML_MAX_DIMS - 1,
0.0f // value 0.0f // value
); );
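The rstd cache is now shaped per row (dropping the embedding dimension) instead of mirroring the full source shape. A worked example of the resulting shape and strides, with illustrative sizes (hidden size 4096, 7 tokens, no extra batch dims):

```cpp
// src->ne = {4096, 7, 1, 1}  ->  acl_rstd_ne = {src->ne[1], src->ne[2], src->ne[3]} = {7, 1, 1}
int64_t acl_rstd_ne[3] = {7, 1, 1};
size_t  acl_rstd_nb[3] = {sizeof(float)};                      // nb[0] = 4 bytes
for (int i = 1; i < 3; i++) {
    acl_rstd_nb[i] = acl_rstd_nb[i - 1] * acl_rstd_ne[i - 1];  // -> {4, 28, 28}
}
```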
@ -1423,21 +1433,25 @@ static void aclnn_pow_tensor_tensor(ggml_backend_cann_context& ctx,
* @param start Starting exponent offset. * @param start Starting exponent offset.
* @param stop Stopping exponent offset (exclusive). * @param stop Stopping exponent offset (exclusive).
* @param step Step size for the exponent increment. * @param step Step size for the exponent increment.
* @param dtype Data type for slope tensor.
*/ */
static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer, static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_buffer,
float m, int64_t size, float start, float stop, float step){ float m, int64_t size, float start, float stop, float step, ggml_type dtype){
int64_t ne[] = {size}; aclDataType acl_type = ggml_cann_type_mapping(dtype);
size_t nb[] = {sizeof(uint16_t)}; size_t type_size = ggml_type_size(dtype);
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * sizeof(uint16_t)); int64_t ne[] = {size};
size_t nb[] = {type_size};
ggml_cann_pool_alloc arange_allocator(ctx.pool(), size * type_size);
void* arange_buffer = arange_allocator.get(); void* arange_buffer = arange_allocator.get();
aclTensor* arange_tensor = ggml_cann_create_tensor( aclTensor* arange_tensor = ggml_cann_create_tensor(
arange_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1); arange_buffer, acl_type, type_size, ne, nb, 1);
aclnn_arange(ctx, arange_tensor, start, stop, step, size); aclnn_arange(ctx, arange_tensor, start, stop, step, size);
aclTensor* slope_tensor = ggml_cann_create_tensor( aclTensor* slope_tensor = ggml_cann_create_tensor(
slope_buffer, ACL_FLOAT16, sizeof(uint16_t), ne, nb, 1); slope_buffer, acl_type, type_size, ne, nb, 1);
aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT); aclScalar* sc = aclCreateScalar(&m, aclDataType::ACL_FLOAT);
@ -1468,10 +1482,11 @@ static void aclnn_get_slope_inner(ggml_backend_cann_context& ctx, void* slope_bu
* @param n_head Total number of attention heads. * @param n_head Total number of attention heads.
* @param slope_buffer Pointer to the output buffer (float array) for storing slopes. * @param slope_buffer Pointer to the output buffer (float array) for storing slopes.
* @param max_bias Maximum bias value for slope computation. * @param max_bias Maximum bias value for slope computation.
* @param dtype Data type for slope tensor.
* *
*/ */
static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head, static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
void* slope_buffer, float max_bias) { void* slope_buffer, float max_bias, ggml_type dtype) {
const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head)); const int n_head_log2 = 1u << (uint32_t) floor(log2(n_head));
float m0 = powf(2.0f, -(max_bias) / n_head_log2); float m0 = powf(2.0f, -(max_bias) / n_head_log2);
@ -1488,7 +1503,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
float step = 1; float step = 1;
float count = n_head_log2; float count = n_head_log2;
// end needs to be +1 because aclnn uses a left-closed, right-open interval. // end needs to be +1 because aclnn uses a left-closed, right-open interval.
aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step); aclnn_get_slope_inner(ctx, slope_buffer, m0, count, start, end + 1, step, dtype);
if (n_head_log2 < n_head) { if (n_head_log2 < n_head) {
// arange2 // arange2
start = 2 * (n_head_log2 - n_head_log2) + 1; start = 2 * (n_head_log2 - n_head_log2) + 1;
@ -1497,7 +1512,7 @@ static void aclnn_get_slope(ggml_backend_cann_context & ctx, int64_t n_head,
count = n_head - n_head_log2; count = n_head - n_head_log2;
aclnn_get_slope_inner( aclnn_get_slope_inner(
ctx, (char *) slope_buffer + n_head_log2 * sizeof(float), ctx, (char *) slope_buffer + n_head_log2 * sizeof(float),
m1, count, start, end + 1, step); m1, count, start, end + 1, step, dtype);
} }
} }
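For reference, a host-side sketch of the slopes these two `aclnn_get_slope_inner` calls materialize: the standard ALiBi head slopes, where heads past the largest power of two switch to base `m1` with odd exponents. This is a reimplementation for illustration, not the CANN code path:

```cpp
#include <cmath>
#include <vector>

// ALiBi slopes: m0 = 2^(-max_bias / n_head_log2), m1 = 2^(-max_bias / (2*n_head_log2))
static std::vector<float> alibi_slopes(int n_head, float max_bias) {
    const int   n_head_log2 = 1 << (int) floorf(log2f((float) n_head));
    const float m0 = powf(2.0f, -max_bias / n_head_log2);
    const float m1 = powf(2.0f, -max_bias / (2 * n_head_log2));

    std::vector<float> slopes(n_head);
    for (int h = 0; h < n_head; ++h) {
        slopes[h] = h < n_head_log2 ? powf(m0, (float) (h + 1))
                                    : powf(m1, (float) (2*(h - n_head_log2) + 1));
    }
    return slopes;
}
```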
@ -1534,7 +1549,7 @@ static void aclnn_add_alibi(ggml_backend_cann_context& ctx, ggml_tensor* mask,
ggml_cann_pool_alloc bias_allocator( ggml_cann_pool_alloc bias_allocator(
ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst)); ctx.pool(), ggml_nelements(dst) * ggml_element_size(dst));
bias_buffer = bias_allocator.get(); bias_buffer = bias_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias); aclnn_get_slope(ctx, n_heads, slope_buffer, max_bias, GGML_TYPE_F32);
} }
// broadcast for mask, slop and dst; // broadcast for mask, slop and dst;
@ -1760,10 +1775,10 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
case GGML_TYPE_F16: { case GGML_TYPE_F16: {
aclTensor* acl_src0 = ggml_cann_create_tensor(src0); aclTensor* acl_src0 = ggml_cann_create_tensor(src0);
ggml_cann_pool_alloc src_buffer_allocator( ggml_cann_pool_alloc src_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); ctx.pool(), ggml_nelements(src0) * sizeof(float));
void* src_trans_buffer = src_buffer_allocator.get(); void* src_trans_buffer = src_buffer_allocator.get();
size_t src_trans_nb[GGML_MAX_DIMS]; size_t src_trans_nb[GGML_MAX_DIMS];
src_trans_nb[0] = sizeof(float_t); src_trans_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1]; src_trans_nb[i] = src_trans_nb[i - 1] * src0->ne[i - 1];
} }
@ -1807,14 +1822,14 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// [3,4,5,64] -> [3,4,5,2,32] // [3,4,5,64] -> [3,4,5,2,32]
dequant_ne = weight_ne; dequant_ne = weight_ne;
dequant_nb[0] = sizeof(float_t); dequant_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS + 1; i++) { for (int i = 1; i < GGML_MAX_DIMS + 1; i++) {
dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1]; dequant_nb[i] = dequant_nb[i - 1] * dequant_ne[i - 1];
} }
scale_offset = ggml_nelements(src0) * sizeof(int8_t); scale_offset = ggml_nelements(src0) * sizeof(int8_t);
ggml_cann_pool_alloc dequant_buffer_allocator( ggml_cann_pool_alloc dequant_buffer_allocator(
ctx.pool(), ggml_nelements(src0) * sizeof(float_t)); ctx.pool(), ggml_nelements(src0) * sizeof(float));
aclTensor* acl_weight_tensor = ggml_cann_create_tensor( aclTensor* acl_weight_tensor = ggml_cann_create_tensor(
src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb, src0->data, ACL_INT8, sizeof(int8_t), weight_ne, weight_nb,
@ -1823,11 +1838,11 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb, src0->data, ACL_FLOAT16, sizeof(uint16_t), scale_ne, scale_nb,
GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset); GGML_MAX_DIMS + 1, ACL_FORMAT_ND, scale_offset);
aclTensor* dequant_tensor = ggml_cann_create_tensor( aclTensor* dequant_tensor = ggml_cann_create_tensor(
dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float_t), dequant_buffer_allocator.get(), ACL_FLOAT, sizeof(float),
dequant_ne, dequant_nb, GGML_MAX_DIMS + 1); dequant_ne, dequant_nb, GGML_MAX_DIMS + 1);
aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor); aclnn_mul(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
dequant_nb[0] = sizeof(float_t); dequant_nb[0] = sizeof(float);
dequant_ne = src0->ne; dequant_ne = src0->ne;
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1]; dequant_nb[i] = dequant_nb[i - 1] * src0->ne[i - 1];
@ -1948,7 +1963,7 @@ static void ggml_cann_mat_mul_fp(ggml_backend_cann_context& ctx,
aclTensor* acl_weight_tensor; aclTensor* acl_weight_tensor;
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (weight_to_nz && is_matmul_weight(weight)) { if (weight_to_nz && is_matmul_weight(weight)) {
int64_t acl_stride[2] = {1, transpose_ne[1]}; int64_t acl_stride[2] = {1, transpose_ne[1]};
@ -2248,46 +2263,35 @@ static void aclnn_index_fill_tensor(ggml_backend_cann_context& ctx,
* 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor. * 5. Compute sin(θ), cos(θ) and optionally scale by attn_factor.
* 6. Expand sin/cos values by repeat or repeat_interleave depending * 6. Expand sin/cos values by repeat or repeat_interleave depending
* on whether @param is_neox is enabled. * on whether @param is_neox is enabled.
* 7. Store the computed values into persistent buffers
* (ctx.rope_sin_ptr / ctx.rope_cos_ptr).
* *
* @param ctx The CANN backend context, holding memory pool, * @param ctx The CANN backend context, holding memory pool,
* stream, and persistent buffers for rope init/cache. * stream, and persistent buffers for rope init/cache.
* @param dst The destination ggml_tensor whose computation * @param dst The destination ggml_tensor whose computation
* depends on the cached RoPE values (usually Qcur/Kcur). * depends on the RoPE values (usually Qcur/Kcur).
* @param theta_scale Scalar exponent base for computing theta scale values. * @param sin_tensor_buffer Pre-allocated buffer for storing repeated sin values.
* @param freq_scale Frequency scaling factor, applied to theta scale. * @param cos_tensor_buffer Pre-allocated buffer for storing repeated cos values.
* @param attn_factor Attention scaling factor, applied to sin/cos. * @param theta_scale Scalar exponent base for computing theta scale values.
* @param is_neox Whether to use Neox-style repeat strategy * @param freq_scale Frequency scaling factor, applied to theta scale.
* (dim expansion vs repeat_interleave). * @param attn_factor Attention scaling factor, applied to sin/cos.
* @param is_neox Whether to use Neox-style repeat strategy
* (dim expansion vs repeat_interleave).
*/ */
static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst, static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
void* sin_tensor_buffer, void* cos_tensor_buffer,
float* corr_dims, float ext_factor,
float theta_scale, float freq_scale, float theta_scale, float freq_scale,
float attn_factor, bool is_neox) { float attn_factor, bool is_neox) {
// int sin/cos cache, cache has different repeat method depond on // int sin/cos cache, cache has different repeat method depond on
// @param.is_neox // @param.is_neox
bool is_q = (std::strncmp(dst->name, "Qcur-", 5) == 0);
bool is_k = (std::strncmp(dst->name, "Kcur-", 5) == 0);
// used for accuracy testing
bool is_attention = is_q || is_k;
// just compute in first layer in attention
bool is_fisrt_layer = (std::strncmp(dst->name, "Qcur-0", GGML_MAX_NAME) == 0);
if(is_attention && !is_fisrt_layer) {
return;
}
ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1]; // position ggml_tensor* src1 = dst->src[1]; // position
ggml_tensor* src2 = dst->src[2]; // freq_factors ggml_tensor* src2 = dst->src[2]; // freq_factors
GGML_TENSOR_BINARY_OP_LOCALS int64_t theta_scale_length = src0->ne[0] / 2;
int64_t theta_scale_length = ne00 / 2;
int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1}; int64_t theta_scale_ne[] = {theta_scale_length, 1, 1, 1};
size_t theta_scale_nb[] = {sizeof(float_t), sizeof(float_t), sizeof(float_t), size_t theta_scale_nb[] = {sizeof(float), sizeof(float), sizeof(float),
theta_scale_length * sizeof(float_t)}; theta_scale_length * sizeof(float)};
GGML_ASSERT(src1->type == GGML_TYPE_I32); GGML_ASSERT(src1->type == GGML_TYPE_I32);
int64_t position_length = src1->ne[0]; int64_t position_length = src1->ne[0];
@ -2297,65 +2301,115 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1}; int64_t theta_ne[] = {theta_scale_length, 1, position_length, 1};
size_t theta_nb[GGML_MAX_DIMS]; size_t theta_nb[GGML_MAX_DIMS];
theta_nb[0] = sizeof(float_t); theta_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1]; theta_nb[i] = theta_nb[i - 1] * theta_ne[i - 1];
} }
// init theta scale, just one time // theta_scale arange, [0,1,...,ne00/2 - 1]
if(ctx.rope_init_ptr == nullptr || !is_attention) { aclTensor* acl_theta_scale_tensor = nullptr;
// theta_scale arange, [0,1,...,ne00/2 - 1] // cache theta scale
if(ctx.rope_init_ptr != nullptr){ if (ctx.rope_cache.theta_scale_length != theta_scale_length ||
ACL_CHECK(aclrtFree(ctx.rope_init_ptr)); // theta_scale and freq_scale should not change during the current token inference process,
} // so we can directly use == here instead of comparing the absolute difference.
ACL_CHECK(aclrtMalloc(&ctx.rope_init_ptr, theta_scale_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST)); ctx.rope_cache.theta_scale != theta_scale ||
ctx.rope_cache.freq_scale != freq_scale) {
aclTensor* acl_theta_scale_tensor = ctx.rope_cache.theta_scale_length = theta_scale_length;
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t), ctx.rope_cache.theta_scale = theta_scale;
ctx.rope_cache.freq_scale = freq_scale;
if (ctx.rope_cache.theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_cache.theta_scale_cache));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_cache.theta_scale_cache, theta_scale_length * sizeof(float), ACL_MEM_MALLOC_HUGE_FIRST));
acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float start = 0; float start = 0;
float step = 1; float step = 1;
float stop = ne00 / 2; float stop = theta_scale_length;
float n_elements = ne00 / 2; float n_elements = theta_scale_length;
aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements); aclnn_arange(ctx, acl_theta_scale_tensor, start, stop, step, n_elements);
ggml_cann_pool_alloc yarn_ramp_allocator(ctx.pool());
aclTensor* acl_yarn_ramp_tensor = nullptr;
if (ext_factor != 0) {
// -rope_yarn_ramp
// const float y = (i0 / 2 - low) / MAX(0.001f, high - low);
// return MIN(1, MAX(0, y)) - 1;
yarn_ramp_allocator.alloc(theta_scale_length * sizeof(float));
void* yarn_ramp_buffer = yarn_ramp_allocator.get();
acl_yarn_ramp_tensor = ggml_cann_create_tensor(yarn_ramp_buffer, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
float zero_value = 0, one_value = 1;
float denom_safe_value = MAX(0.001f, corr_dims[1] - corr_dims[0]);
aclScalar* low = aclCreateScalar(&corr_dims[0], aclDataType::ACL_FLOAT);
aclScalar* zero = aclCreateScalar(&zero_value, aclDataType::ACL_FLOAT);
aclScalar* one = aclCreateScalar(&one_value, aclDataType::ACL_FLOAT);
aclScalar* denom_safe = aclCreateScalar(&denom_safe_value, aclDataType::ACL_FLOAT);
aclScalar* ext_factor_sc = aclCreateScalar(&ext_factor, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, Subs, acl_theta_scale_tensor, low, one, acl_yarn_ramp_tensor);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceDivs, acl_yarn_ramp_tensor, denom_safe);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceThreshold, acl_yarn_ramp_tensor, zero, zero);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceClampMax, acl_yarn_ramp_tensor, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceSubs, acl_yarn_ramp_tensor, one, one);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, ext_factor_sc);
// theta_interp = freq_scale * theta_extrap;
// theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap * (1 - ramp_mix) + theta_extrap * ramp_mix;
// theta = freq_scale * theta_extrap - freq_scale * theta_extrap * ramp_mix + theta_extrap * ramp_mix;
// theta = theta_extrap * (freq_scale - freq_scale * ramp_mix + ramp_mix);
//
// we cache (freq_scale - freq_scale * ramp_mix + ramp_mix), Considering that the rope_yarn_ramp here is the inverse
// cache freq_scale + (freq_scale - 1) * ramp_mix
float freq_scale_1 = freq_scale - 1;
aclScalar* freq_scale_sc = aclCreateScalar(&freq_scale, aclDataType::ACL_FLOAT);
aclScalar* freq_scale_1_sc = aclCreateScalar(&freq_scale_1, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMuls, acl_yarn_ramp_tensor, freq_scale_1_sc);
GGML_CANN_CALL_ACLNN_OP(ctx, InplaceAdds, acl_yarn_ramp_tensor, freq_scale_sc, one);
ggml_cann_release_resources(ctx, low, zero, one, denom_safe, ext_factor_sc, freq_scale_sc, freq_scale_1_sc);
}
// power // power
aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT); aclScalar* acl_theta_scale = aclCreateScalar(&theta_scale, aclDataType::ACL_FLOAT);
GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor, GGML_CANN_CALL_ACLNN_OP(ctx, PowScalarTensor, acl_theta_scale, acl_theta_scale_tensor,
acl_theta_scale_tensor); acl_theta_scale_tensor);
// freq_scale if (ext_factor != 0) {
if (freq_scale != 1) { aclnn_mul(ctx, acl_theta_scale_tensor, acl_yarn_ramp_tensor);
} else if (freq_scale != 1) {
aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true); aclnn_muls(ctx, acl_theta_scale_tensor, freq_scale, nullptr, true);
} }
// freq_factors ggml_cann_release_resources(ctx, acl_yarn_ramp_tensor, acl_theta_scale);
if (src2) { } else {
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor( // use cache
src2->data, ggml_cann_type_mapping(src2->type), acl_theta_scale_tensor =
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); ggml_cann_create_tensor(ctx.rope_cache.theta_scale_cache, ACL_FLOAT, sizeof(float),
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor);
}
// release
ggml_cann_release_resources(ctx, acl_theta_scale_tensor,acl_theta_scale);
}
// init sin_repeat && cos_repeat, one token just init in 0 layer
if(position_length > ctx.max_prompt_length) {
ctx.max_prompt_length = position_length;
int64_t repeat_theta_length = theta_scale_length * ctx.max_prompt_length * 2;
if(ctx.rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(ctx.rope_sin_ptr));
ACL_CHECK(aclrtFree(ctx.rope_cos_ptr));
}
ACL_CHECK(aclrtMalloc(&ctx.rope_sin_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
ACL_CHECK(aclrtMalloc(&ctx.rope_cos_ptr, repeat_theta_length * sizeof(float_t), ACL_MEM_MALLOC_HUGE_FIRST));
}
aclTensor* acl_theta_scale_tensor =
ggml_cann_create_tensor(ctx.rope_init_ptr, ACL_FLOAT, sizeof(float_t),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS); theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
}
ggml_cann_pool_alloc freq_fac_res_allocator(ctx.pool());
// freq_factors
if (src2) {
freq_fac_res_allocator.alloc(theta_scale_length * sizeof(float));
void* freq_fac_res_ptr = freq_fac_res_allocator.get();
aclTensor* acl_freq_factors_tensor = ggml_cann_create_tensor(
src2->data, ggml_cann_type_mapping(src2->type),
ggml_type_size(src2->type), theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclTensor* acl_freq_fac_res_tensor = ggml_cann_create_tensor(
freq_fac_res_ptr, ACL_FLOAT, sizeof(float),
theta_scale_ne, theta_scale_nb, GGML_MAX_DIMS);
aclnn_div(ctx, acl_theta_scale_tensor, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
std::swap(acl_theta_scale_tensor, acl_freq_fac_res_tensor);
ggml_cann_release_resources(ctx, acl_freq_factors_tensor, acl_freq_fac_res_tensor);
}
// position // position
aclTensor* acl_position_tensor = ggml_cann_create_tensor( aclTensor* acl_position_tensor = ggml_cann_create_tensor(
@ -2365,49 +2419,53 @@ static void aclnn_cache_init(ggml_backend_cann_context& ctx, ggml_tensor* dst,
// power * position // power * position
int64_t theta_length = theta_scale_length * position_length; int64_t theta_length = theta_scale_length * position_length;
ggml_cann_pool_alloc theta_allocator(ctx.pool(), ggml_cann_pool_alloc theta_allocator(ctx.pool(),
theta_length * sizeof(float_t)); theta_length * sizeof(float));
void* theta_buffer = theta_allocator.get(); void* theta_buffer = theta_allocator.get();
aclTensor* acl_theta_tensor = aclTensor* acl_theta_tensor =
ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float_t), ggml_cann_create_tensor(theta_buffer, ACL_FLOAT, sizeof(float),
theta_ne, theta_nb, GGML_MAX_DIMS); theta_ne, theta_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor, aclnn_mul(ctx, acl_position_tensor, acl_theta_scale_tensor,
acl_theta_tensor); acl_theta_tensor);
// sin/cos // sin/cos
ggml_cann_pool_alloc sin_allocator(ctx.pool(), ggml_cann_pool_alloc sin_allocator(ctx.pool(),
theta_length * sizeof(float_t)); theta_length * sizeof(float));
void* sin_buffer = sin_allocator.get(); void* sin_buffer = sin_allocator.get();
aclTensor* acl_sin_tensor = ggml_cann_create_tensor( aclTensor* acl_sin_tensor = ggml_cann_create_tensor(
sin_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, sin_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND); GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor); aclnn_sin(ctx, acl_theta_tensor, acl_sin_tensor);
ggml_cann_pool_alloc cos_allocator(ctx.pool(), ggml_cann_pool_alloc cos_allocator(ctx.pool(),
theta_length * sizeof(float_t)); theta_length * sizeof(float));
void* cos_buffer = cos_allocator.get(); void* cos_buffer = cos_allocator.get();
aclTensor* acl_cos_tensor = ggml_cann_create_tensor( aclTensor* acl_cos_tensor = ggml_cann_create_tensor(
cos_buffer, ACL_FLOAT, sizeof(float_t), theta_ne, theta_nb, cos_buffer, ACL_FLOAT, sizeof(float), theta_ne, theta_nb,
GGML_MAX_DIMS, ACL_FORMAT_ND); GGML_MAX_DIMS, ACL_FORMAT_ND);
aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor); aclnn_cos(ctx, acl_theta_tensor, acl_cos_tensor);
if (ext_factor != 0) {
attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale);
}
// attn_factor // attn_factor
if (attn_factor != 1) { if (attn_factor != 1) {
aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true); aclnn_muls(ctx, acl_sin_tensor, attn_factor, nullptr, true);
aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true); aclnn_muls(ctx, acl_cos_tensor, attn_factor, nullptr, true);
} }
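A hedged scalar sketch of what ends up in the sin/cos buffers for one position pos and half-dimension index i (illustrative only; the YaRN ramp mixing from the previous block is folded into the cached theta_scale factor and omitted here):

#include <math.h>

static void rope_cache_entry(float pos, int i, int n_dims,
                             float freq_base, float freq_scale, float attn_factor,
                             float ext_factor, const float * freq_factors, // may be NULL
                             float * out_sin, float * out_cos) {
    const float theta_scale = powf(freq_base, -2.0f / n_dims);
    float theta = pos * powf(theta_scale, (float) i);     // power * position
    if (freq_factors) {
        theta /= freq_factors[i];                         // freq_factors division
    }
    theta *= freq_scale;                                  // ext_factor == 0 path; with YaRN the
                                                          // ramp-mixed factor from above is used instead
    if (ext_factor != 0.0f) {
        attn_factor *= 1.0f + 0.1f * logf(1.0f / freq_scale); // YaRN magnitude correction
    }
    *out_sin = sinf(theta) * attn_factor;
    *out_cos = cosf(theta) * attn_factor;
}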
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; int64_t sin_reshape_ne[4] = {src0->ne[0], 1, src0->ne[2], 1};
size_t sin_reshape_nb[GGML_MAX_DIMS]; size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t); sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
} }
aclTensor* acl_sin_repeat_tensor = aclTensor* acl_sin_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_repeat_tensor = aclTensor* acl_cos_repeat_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
// repeat // repeat
@ -2449,6 +2507,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: use ascendc // TODO: use ascendc
// Only test with LLAMA model. // Only test with LLAMA model.
ggml_tensor* src0 = dst->src[0]; // input ggml_tensor* src0 = dst->src[0]; // input
ggml_tensor* src1 = dst->src[1];
// param // param
float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow; float freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow;
@ -2470,8 +2529,6 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: n_dims <= ne0 // TODO: n_dims <= ne0
GGML_ASSERT(n_dims == ne0); GGML_ASSERT(n_dims == ne0);
GGML_ASSERT(n_dims % 2 == 0); GGML_ASSERT(n_dims % 2 == 0);
// TODO: ext_factor != 0
GGML_ASSERT(ext_factor == 0);
const float theta_scale = powf(freq_base, -2.0f / n_dims); const float theta_scale = powf(freq_base, -2.0f / n_dims);
@ -2481,20 +2538,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
// sin/cos tensor length.
int64_t repeat_theta_length = src0->ne[0] * src1->ne[0];
ggml_cann_pool_alloc sin_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
ggml_cann_pool_alloc cos_tensor_allocator(ctx.pool(), repeat_theta_length * sizeof(float));
void *sin_tensor_buffer = sin_tensor_allocator.get();
void *cos_tensor_buffer = cos_tensor_allocator.get();
// init ctx.rope_cos/rope_sin cache // init ctx.rope_cos/rope_sin cache
aclnn_cache_init(ctx, dst, theta_scale, freq_scale, attn_factor, is_neox); aclnn_cache_init(ctx, dst, sin_tensor_buffer, cos_tensor_buffer, corr_dims, ext_factor,
theta_scale, freq_scale, attn_factor, is_neox);
int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1}; int64_t sin_reshape_ne[4] = {ne00, 1, ne02, 1};
size_t sin_reshape_nb[GGML_MAX_DIMS]; size_t sin_reshape_nb[GGML_MAX_DIMS];
sin_reshape_nb[0] = sizeof(float_t); sin_reshape_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1]; sin_reshape_nb[i] = sin_reshape_nb[i - 1] * sin_reshape_ne[i - 1];
} }
aclTensor* acl_sin_reshape_tensor = aclTensor* acl_sin_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_sin_ptr, ACL_FLOAT, sizeof(float_t), ggml_cann_create_tensor(sin_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_cos_reshape_tensor = aclTensor* acl_cos_reshape_tensor =
ggml_cann_create_tensor(ctx.rope_cos_ptr, ACL_FLOAT, sizeof(float_t), ggml_cann_create_tensor(cos_tensor_buffer, ACL_FLOAT, sizeof(float),
sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS); sin_reshape_ne, sin_reshape_nb, GGML_MAX_DIMS);
aclTensor* acl_src = ggml_cann_create_tensor(src0); aclTensor* acl_src = ggml_cann_create_tensor(src0);
@ -2509,7 +2574,7 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
void* minus_one_scale_buffer = nullptr; void* minus_one_scale_buffer = nullptr;
ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0)); ggml_cann_pool_alloc roll_allocator(ctx.pool(), ggml_nbytes(src0));
ggml_cann_pool_alloc minus_one_scale_allocator( ggml_cann_pool_alloc minus_one_scale_allocator(
ctx.pool(), sizeof(float_t) * src0->ne[0]); ctx.pool(), sizeof(float) * src0->ne[0]);
if (!is_neox) { if (!is_neox) {
// roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...] // roll input: [q0,q1,q2,q3,...] -> [q1,q0,q3,q2,...]
input_roll_buffer = roll_allocator.get(); input_roll_buffer = roll_allocator.get();
@ -2539,13 +2604,13 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS]; size_t minus_one_nb[GGML_MAX_DIMS];
minus_one_nb[0] = sizeof(float_t); minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
} }
acl_minus_one_tensor = aclnn_values( acl_minus_one_tensor = aclnn_values(
ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
int64_t dim = 3; int64_t dim = 3;
int64_t* index = new int64_t[src0->ne[0]]; int64_t* index = new int64_t[src0->ne[0]];
for (int i = 0; i < src0->ne[0]; i++) { for (int i = 0; i < src0->ne[0]; i++) {
@ -2573,22 +2638,22 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
minus_one_scale_buffer = minus_one_scale_allocator.get(); minus_one_scale_buffer = minus_one_scale_allocator.get();
int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1}; int64_t minus_one_ne[4] = {src0->ne[0], 1, 1, 1};
size_t minus_one_nb[GGML_MAX_DIMS]; size_t minus_one_nb[GGML_MAX_DIMS];
minus_one_nb[0] = sizeof(float_t); minus_one_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1]; minus_one_nb[i] = minus_one_nb[i - 1] * minus_one_ne[i - 1];
} }
acl_minus_one_tensor = aclnn_values( acl_minus_one_tensor = aclnn_values(
ctx, minus_one_scale_buffer, sizeof(float_t) * src0->ne[0], ctx, minus_one_scale_buffer, sizeof(float) * src0->ne[0],
minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float_t), 1); minus_one_ne, GGML_MAX_DIMS, ACL_FLOAT, sizeof(float), 1);
// -1 * first half // -1 * first half
int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1}; int64_t first_half_ne[4] = {src0->ne[0] / 2, 1, 1, 1};
size_t first_half_nb[GGML_MAX_DIMS]; size_t first_half_nb[GGML_MAX_DIMS];
first_half_nb[0] = sizeof(float_t); first_half_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1]; first_half_nb[i] = first_half_nb[i - 1] * first_half_ne[i - 1];
} }
aclTensor* acl_first_half_tensor = ggml_cann_create_tensor( aclTensor* acl_first_half_tensor = ggml_cann_create_tensor(
minus_one_scale_buffer, ACL_FLOAT, sizeof(float_t), first_half_ne, minus_one_scale_buffer, ACL_FLOAT, sizeof(float), first_half_ne,
first_half_nb, GGML_MAX_DIMS); first_half_nb, GGML_MAX_DIMS);
bool inplace = true; bool inplace = true;
float scale = -1; float scale = -1;
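The roll plus the alternating -1/+1 scale implement the usual pairwise rotation; a scalar sketch for the non-NeoX layout (adjacent pairs; illustrative only):

// roll:  [q0, q1, q2, q3, ...] -> [q1, q0, q3, q2, ...]
// scale: [-1, +1, -1, +1, ...]
// so dst = q*cos + roll(q)*scale*sin, i.e. per pair:
static void rope_rotate_pair(float q0, float q1, float c, float s,
                             float * d0, float * d1) {
    *d0 = q0 * c - q1 * s;
    *d1 = q1 * c + q0 * s;
}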
@ -2628,28 +2693,28 @@ void ggml_cann_rope(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
// TODO: ne0 != n_dims in mode2 // TODO: ne0 != n_dims in mode2
} else if (src0->type == GGML_TYPE_F16) { } else if (src0->type == GGML_TYPE_F16) {
size_t input_fp32_nb[GGML_MAX_DIMS]; size_t input_fp32_nb[GGML_MAX_DIMS];
input_fp32_nb[0] = sizeof(float_t); input_fp32_nb[0] = sizeof(float);
for (int i = 1; i < GGML_MAX_DIMS; i++) { for (int i = 1; i < GGML_MAX_DIMS; i++) {
input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1]; input_fp32_nb[i] = input_fp32_nb[i - 1] * dst->ne[i - 1];
} }
ggml_cann_pool_alloc fp32_allocator1( ggml_cann_pool_alloc fp32_allocator1(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer1 = fp32_allocator1.get(); void* input_fp32_buffer1 = fp32_allocator1.get();
aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor( aclTensor* input_fp32_tensor1 = ggml_cann_create_tensor(
input_fp32_buffer1, ACL_FLOAT, sizeof(float_t), dst->ne, input_fp32_buffer1, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS); input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator2( ggml_cann_pool_alloc fp32_allocator2(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); ctx.pool(), ggml_nelements(dst) * sizeof(float));
void* input_fp32_buffer2 = fp32_allocator2.get(); void* input_fp32_buffer2 = fp32_allocator2.get();
aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor( aclTensor* input_fp32_tensor2 = ggml_cann_create_tensor(
input_fp32_buffer2, ACL_FLOAT, sizeof(float_t), dst->ne, input_fp32_buffer2, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS); input_fp32_nb, GGML_MAX_DIMS);
ggml_cann_pool_alloc fp32_allocator( ggml_cann_pool_alloc fp32_allocator(
ctx.pool(), ggml_nelements(dst) * sizeof(float_t)); ctx.pool(), ggml_nelements(dst) * sizeof(float));
output_fp32_buffer = fp32_allocator.get(); output_fp32_buffer = fp32_allocator.get();
aclTensor* output_fp32_tensor = ggml_cann_create_tensor( aclTensor* output_fp32_tensor = ggml_cann_create_tensor(
output_fp32_buffer, ACL_FLOAT, sizeof(float_t), dst->ne, output_fp32_buffer, ACL_FLOAT, sizeof(float), dst->ne,
input_fp32_nb, GGML_MAX_DIMS); input_fp32_nb, GGML_MAX_DIMS);
aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1); aclnn_mul(ctx, acl_src, acl_cos_reshape_tensor, input_fp32_tensor1);
aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor, aclnn_mul(ctx, acl_input_roll_mul_scale_tensor, acl_sin_reshape_tensor,
@ -2746,8 +2811,6 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
aclIntArray *padding = aclCreateIntArray(paddingVal, 1); aclIntArray *padding = aclCreateIntArray(paddingVal, 1);
int64_t dilationVal[] = {1}; int64_t dilationVal[] = {1};
aclIntArray *dilation = aclCreateIntArray(dilationVal, 1); aclIntArray *dilation = aclCreateIntArray(dilationVal, 1);
bool transposed = true;
int64_t groups = 1;
int8_t cubeMathType = 0; int8_t cubeMathType = 0;
#ifdef ASCEND_310P #ifdef ASCEND_310P
@ -2755,7 +2818,7 @@ void ggml_cann_conv_transpose_1d(ggml_backend_cann_context& ctx, ggml_tensor* ds
#endif #endif
GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride, GGML_CANN_CALL_ACLNN_OP(ctx, Convolution, acl_input, acl_weight, nullptr, stride,
padding, dilation, transposed, padding, groups, acl_dst, cubeMathType); padding, dilation, true, padding, 1, acl_dst, cubeMathType);
ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation); ggml_cann_release_resources(ctx, acl_weight, acl_dst, stride, padding, dilation);
} }
@ -2864,174 +2927,49 @@ void ggml_cann_step(ggml_backend_cann_context& ctx, ggml_tensor* dst){
*/ */
static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) { static void ggml_cann_mul_mat_id_fp(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
//dst [M, K, N, 1] //dst [M, K, N, 1]
ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] ggml_tensor * src0 = dst->src[0]; //src0 [D, M, A, 1] -> [D, M, K, 1]
ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 ggml_tensor * src1 = dst->src[1]; //src1 [D, B, N, 1], B = K or B = 1 -> [D, 1, K, 1]
ggml_tensor * ids = dst->src[2]; //ids [K, N] ggml_tensor * ids = dst->src[2]; //ids [K, N]
GGML_TENSOR_BINARY_OP_LOCALS GGML_ASSERT(src0->ne[3] == 1);
GGML_ASSERT(src1->ne[3] == 1);
GGML_ASSERT(dst->ne[3] == 1);
// copy index from npu to cpu int64_t batch = src1->ne[2];
int64_t n_as = ne02; // A GGML_ASSERT(batch == ids->ne[1]);
int64_t n_ids = ids->ne[0]; // K
std::vector<char> ids_host(ggml_nbytes(ids)); ggml_cann_pool_alloc export_allocator(ctx.pool(), src0->ne[0] * src0->ne[1] * ids->ne[0] * ggml_element_size(src0));
ggml_cann_async_memcpy(ctx, ids_host.data(), ids->data, ggml_nbytes(ids), void* export_ptr = export_allocator.get();
ACL_MEMCPY_DEVICE_TO_HOST); for (int64_t i = 0; i < batch; i++) {
ACL_CHECK(aclrtSynchronizeStream(ctx.stream())); aclTensor *select_index = ggml_cann_create_tensor(ids, ids->ne, ids->nb, 1, ACL_FORMAT_ND, i * ids->nb[1]);
aclTensor *export_weight = ggml_cann_create_tensor(src0, src0->ne, src0->nb, 3);
char * src0_original = (char *) src0->data; int64_t select_export_ne[] = {src0->ne[0], src0->ne[1], ids->ne[0]};
char * src1_original = (char *) src1->data; size_t select_export_nb[3];
char * dst_original = (char *) dst->data; select_export_nb[0] = src0->nb[0];
size_t ori_src0_nb[4] = {nb00, nb01, nb02, nb03}; for (int k = 1;k < 3; k++) {
select_export_nb[k] = select_export_nb[k-1] * select_export_ne[k-1];
// src0 is F16, src1 is F32, dst is F32
ggml_cann_pool_alloc src0_cast_allocator;
if (src0->type == GGML_TYPE_F16) {
src0_cast_allocator.alloc(ctx.pool(), sizeof(float) * ggml_nelements(src0));
void* src0_cast_buf = src0_cast_allocator.get();
size_t cast_nb[GGML_MAX_DIMS];
cast_nb[0] = sizeof(float_t);
for (int i = 1; i < GGML_MAX_DIMS; i++) {
cast_nb[i] = cast_nb[i - 1] * src0->ne[i - 1];
} }
aclTensor* acl_src0_f16 = ggml_cann_create_tensor(src0); aclTensor *select_export = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_export_ne, select_export_nb, 3);
aclTensor* acl_cast = ggml_cann_create_tensor(src0_cast_buf, GGML_CANN_CALL_ACLNN_OP(ctx, IndexSelect, export_weight, 0, select_index, select_export);
ACL_FLOAT, sizeof(float), src0->ne, cast_nb, 4);
GGML_CANN_CALL_ACLNN_OP(ctx, Cast, acl_src0_f16, ACL_FLOAT, acl_cast);
ggml_cann_release_resources(ctx, acl_cast, acl_src0_f16);
src0_original = (char *) src0_cast_buf; int64_t select_transpose_ne[] = {select_export_ne[1], select_export_ne[0], select_export_ne[2]};
memcpy(ori_src0_nb, cast_nb, sizeof(ori_src0_nb)); size_t select_transpose_nb[] = {select_export_nb[1], select_export_nb[0], select_export_nb[2]};
aclTensor *select_export_transpose = ggml_cann_create_tensor(export_ptr, ggml_cann_type_mapping(src0->type), ggml_element_size(src0), select_transpose_ne, select_transpose_nb, 3);
int64_t active_tensor_ne[] = {src1->ne[0], 1, src1->ne[1]};
size_t active_tensor_nb[] = {src1->nb[0], src1->nb[1], src1->nb[1]};
aclTensor *active_tensor = ggml_cann_create_tensor(src1, active_tensor_ne, active_tensor_nb, 3, ACL_FORMAT_ND, i * src1->nb[2]);
int64_t dst_ne[] = {dst->ne[0], 1, dst->ne[1]};
size_t dst_nb[] = {dst->nb[0], dst->nb[1], dst->nb[1]};
aclTensor *acl_dst = ggml_cann_create_tensor(dst, dst_ne,dst_nb, 3, ACL_FORMAT_ND, i * dst->nb[2]);
GGML_CANN_CALL_ACLNN_OP(ctx, BatchMatMul, active_tensor, select_export_transpose, acl_dst, 2);
ggml_cann_release_resources(ctx, select_index, export_weight, select_export, active_tensor, acl_dst, select_export_transpose);
} }
#ifdef ASCEND_310P
ggml_tensor src0_row = *src0;
ggml_tensor src1_row = *src1;
ggml_tensor dst_row = *dst;
if (src0->type == GGML_TYPE_F16) {
src0_row.type = GGML_TYPE_F32;
}
// src0_row [D, M, 1, 1] weight without permute
src0_row.ne[2] = 1;
src0_row.ne[3] = 1;
src0_row.nb[0] = ori_src0_nb[0];
src0_row.nb[1] = ori_src0_nb[1];
src0_row.nb[2] = ori_src0_nb[1];
src0_row.nb[3] = ori_src0_nb[1];
// src1_row [D, 1, 1, 1] -> input
src1_row.ne[1] = 1;
src1_row.ne[2] = 1;
src1_row.ne[3] = 1;
src1_row.nb[2] = nb11;
src1_row.nb[3] = nb11;
// dst_row [M, 1, 1, 1] -> out
dst_row.ne[1] = 1;
dst_row.ne[2] = 1;
dst_row.ne[3] = 1;
dst_row.nb[2] = nb1;
dst_row.nb[3] = nb1;
//create weight for one row
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
// expert index
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;
int64_t i1 = id;
int64_t i2 = i12;
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
src0_row.data = src0_tmp_ptr;
src1_row.data = src1_tmp_ptr;
dst_row.data = dst_tmp_ptr;
dst_row.src[0] = &src0_row;
dst_row.src[1] = &src1_row;
ggml_cann_mul_mat(ctx, &dst_row);
}
}
return;
#endif
std::vector<aclTensor*> src0_tensor_vec;
std::vector<aclTensor*> src1_tensor_vec;
std::vector<aclTensor*> dst_tensor_vec;
for (int64_t iid1 = 0; iid1 < ids->ne[1]; iid1++) {
for (int64_t id = 0; id < n_ids; id++) {
// src0_row [M, D] -> weight && permute
int64_t src0_ne[2] = {ne01, ne00};
size_t src0_nb[2] = {ori_src0_nb[1], ori_src0_nb[0]};
// src1_row [D, 1] -> input
int64_t src1_ne[2] = {ne10, 1};
size_t src1_nb[2] = {nb10, nb11};
// dst_row [M, 1] -> out
int64_t dst_ne[2] = {ne0, 1};
size_t dst_nb[2] = {nb0, nb1};
// expert index
int32_t i02 = *(int32_t *) (ids_host.data() + iid1*ids->nb[1] + id*ids->nb[0]);
GGML_ASSERT(i02 >= 0 && i02 < n_as);
// If B = 1 (broadcast), always use 0; otherwise, use id.
int64_t i11 = (ne11 == 1 ? 0 : id);
int64_t i12 = iid1;
int64_t i1 = id;
int64_t i2 = i12;
void* src0_tmp_ptr = src0_original + i02*ori_src0_nb[2];
void* src1_tmp_ptr = src1_original + i11*nb11 + i12*nb12;
void* dst_tmp_ptr = dst_original + i1*nb1 + i2*nb2;
aclTensor* acl_src0 = ggml_cann_create_tensor(src0_tmp_ptr,
ACL_FLOAT, sizeof(float),
src0_ne, src0_nb, 2);
aclTensor* acl_src1 = ggml_cann_create_tensor(src1_tmp_ptr,
ACL_FLOAT, sizeof(float),
src1_ne, src1_nb, 2);
aclTensor* acl_dst = ggml_cann_create_tensor(dst_tmp_ptr,
ACL_FLOAT, sizeof(float),
dst_ne, dst_nb, 2);
src0_tensor_vec.push_back(acl_src0);
src1_tensor_vec.push_back(acl_src1);
dst_tensor_vec.push_back(acl_dst);
}
}
size_t GROUP_SIZE = 128;
// GroupedMatmulV3 requires tensor_list.size < 128
for (size_t i = 0; i < src0_tensor_vec.size(); i += GROUP_SIZE) {
// split and call GroupedMatmulV3
size_t end = std::min(i + GROUP_SIZE, src0_tensor_vec.size());
std::vector<aclTensor*> src0_tensor_vec_split(src0_tensor_vec.begin() + i, src0_tensor_vec.begin() + end);
std::vector<aclTensor*> src1_tensor_vec_split(src1_tensor_vec.begin() + i, src1_tensor_vec.begin() + end);
std::vector<aclTensor*> dst_tensor_vec_split(dst_tensor_vec.begin() + i, dst_tensor_vec.begin() + end);
aclTensorList* src0_tensor_list = aclCreateTensorList(src0_tensor_vec_split.data(), src0_tensor_vec_split.size());
aclTensorList* src1_tensor_list = aclCreateTensorList(src1_tensor_vec_split.data(), src1_tensor_vec_split.size());
aclTensorList* dst_tensor_list = aclCreateTensorList(dst_tensor_vec_split.data(), dst_tensor_vec_split.size());
GGML_CANN_CALL_ACLNN_OP(ctx, GroupedMatmulV3, src1_tensor_list, src0_tensor_list,
nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, 0, -1, dst_tensor_list);
ggml_cann_release_resources(ctx, src0_tensor_list, src1_tensor_list, dst_tensor_list);
}
return;
} }
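For the common B = 1 (broadcast activation) case, a scalar reference sketch of the gather-then-multiply flow above (illustrative only, not the ACL path; the pointer arithmetic mirrors the IndexSelect and BatchMatMul steps):

#include <stddef.h>
#include <stdint.h>

// dst[b][k][m] = dot(row m of expert ids[b][k], activation of token b)
static void mul_mat_id_ref(const float   * w,    // [n_expert][M][D] expert weights
                           const float   * x,    // [batch][D]       activations
                           const int32_t * ids,  // [batch][K]       selected experts
                           float * dst,          // [batch][K][M]
                           int M, int D, int K, int batch) {
    for (int b = 0; b < batch; ++b) {
        for (int k = 0; k < K; ++k) {
            const float * wk = w + (size_t) ids[b*K + k] * M * D;  // the IndexSelect step
            for (int m = 0; m < M; ++m) {                          // the BatchMatMul step
                float acc = 0.0f;
                for (int d = 0; d < D; ++d) {
                    acc += wk[(size_t) m*D + d] * x[(size_t) b*D + d];
                }
                dst[((size_t) b*K + k)*M + m] = acc;
            }
        }
    }
}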
/** /**
@ -3342,7 +3280,7 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
const int64_t n_heads = src0->ne[2]; const int64_t n_heads = src0->ne[2];
ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t)); ggml_cann_pool_alloc slope_allocator(ctx.pool(), n_heads * sizeof(uint16_t));
void* slope_buffer = slope_allocator.get(); void* slope_buffer = slope_allocator.get();
aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias); aclnn_get_slope(ctx, n_heads, slope_buffer, maxBias, GGML_TYPE_F16);
int64_t slope_ne[] = {1, 1, n_heads, 1}; int64_t slope_ne[] = {1, 1, n_heads, 1};
size_t slope_nb[GGML_MAX_DIMS]; size_t slope_nb[GGML_MAX_DIMS];

View File

@ -360,6 +360,30 @@ struct ggml_cann_graph {
}; };
#endif // USE_ACL_GRAPH #endif // USE_ACL_GRAPH
struct ggml_cann_rope_cache {
~ggml_cann_rope_cache() {
if(theta_scale_cache != nullptr) {
ACL_CHECK(aclrtFree(theta_scale_cache));
}
}
void* theta_scale_cache = nullptr;
int64_t theta_scale_length = 0;
float theta_scale = 0.0f;
float freq_scale = 0.0f;
};
struct ggml_cann_tensor_cache {
~ggml_cann_tensor_cache() {
if(cache != nullptr) {
ACL_CHECK(aclrtFree(cache));
}
}
void* cache = nullptr;
int64_t size = 0;
};
/** /**
* @brief Context for managing CANN backend operations. * @brief Context for managing CANN backend operations.
*/ */
@ -371,19 +395,15 @@ struct ggml_backend_cann_context {
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
/// Cached CANN ACL graph used for executing the current ggml computation graph. /// Cached CANN ACL graph used for executing the current ggml computation graph.
std::unique_ptr<ggml_cann_graph> cann_graph; std::unique_ptr<ggml_cann_graph> cann_graph;
bool acl_graph_mode = true;
#endif #endif
cann_task_queue task_queue; cann_task_queue task_queue;
bool async_mode; bool async_mode;
// Rope Cache // Rope Cache
void* rope_init_ptr = nullptr; ggml_cann_rope_cache rope_cache;
void* rope_sin_ptr = nullptr;
void* rope_cos_ptr = nullptr;
int64_t max_prompt_length = 0;
// Constant Pool // Constant Pool
void* f32_zero_cache = nullptr; ggml_cann_tensor_cache rms_norm_one_tensor_cache;
void* f32_one_cache = nullptr; ggml_cann_tensor_cache rms_norm_zero_tensor_cache;
int64_t f32_zero_cache_element = 0;
int64_t f32_one_cache_element = 0;
aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */ aclrtStream streams[GGML_CANN_MAX_STREAMS] = {nullptr}; /**< Array of streams for the device. */
@ -399,6 +419,13 @@ struct ggml_backend_cann_context {
async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or("")); async_mode = parse_bool(get_env("GGML_CANN_ASYNC_MODE").value_or(""));
GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__, GGML_LOG_INFO("%s: device %d async operator submission is %s\n", __func__,
device, async_mode ? "ON" : "OFF"); device, async_mode ? "ON" : "OFF");
#ifdef USE_ACL_GRAPH
acl_graph_mode = parse_bool(get_env("GGML_CANN_ACL_GRAPH").value_or("on"));
GGML_LOG_INFO("%s: device %d execution mode is %s (%s)\n",
__func__, device,
acl_graph_mode ? "GRAPH" : "EAGER",
acl_graph_mode ? "acl graph enabled" : "acl graph disabled");
#endif
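// e.g. launching with GGML_CANN_ACL_GRAPH=off forces EAGER (op-by-op) execution,
// while the default "on" keeps ACL graph capture enabled (see the use_cann_graph
// check in ggml_backend_cann_graph_compute).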
} }
/** /**
@ -415,21 +442,6 @@ struct ggml_backend_cann_context {
ACL_CHECK(aclrtDestroyStream(streams[i])); ACL_CHECK(aclrtDestroyStream(streams[i]));
} }
} }
if(rope_init_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_init_ptr));
}
if(rope_sin_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_sin_ptr));
}
if(rope_cos_ptr != nullptr) {
ACL_CHECK(aclrtFree(rope_cos_ptr));
}
if(f32_zero_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_zero_cache));
}
if(f32_one_cache != nullptr) {
ACL_CHECK(aclrtFree(f32_one_cache));
}
} }
/** /**

View File

@ -1116,30 +1116,65 @@ static enum ggml_status ggml_backend_cann_buffer_init_tensor(
return GGML_STATUS_SUCCESS; return GGML_STATUS_SUCCESS;
} }
// ND to NZ Workspace Cache Management. Thread-safety: Not guaranteed /**
namespace { * @brief Workspace for caching NZ buffers per device.
void* g_nz_workspace = nullptr; *
size_t g_nz_workspace_allocated = 0; * This struct manages a device buffer used in NZ computations. It supports
* allocation, reallocation, and clearing of cached memory. The struct is
* designed to be used with a global array, one per device.
*/
struct ggml_cann_nz_workspace {
void* ptr; // Pointer to allocated device buffer
size_t allocated; // Size of currently allocated buffer in bytes
void release_nz_workspace() { /**
if (g_nz_workspace) { * @brief Constructor. Initializes the workspace with no allocated memory.
aclrtFree(g_nz_workspace); */
g_nz_workspace = nullptr; ggml_cann_nz_workspace() : ptr(nullptr), allocated(0) {}
g_nz_workspace_allocated = 0;
/**
* @brief Free cached memory and reset the workspace.
*
* If a buffer has been allocated, this function releases it using
* aclrtFree and resets internal state.
*/
void clear() {
if (ptr) {
ACL_CHECK(aclrtFree(ptr));
ptr = nullptr;
allocated = 0;
} }
} }
void relloc_nz_workspace(size_t new_size) { /**
if (new_size > g_nz_workspace_allocated) { * @brief Allocate or reallocate the workspace buffer.
if (g_nz_workspace) { *
aclrtFree(g_nz_workspace); * If the requested size is larger than the currently allocated size,
g_nz_workspace = nullptr; * the old buffer will be freed and a new buffer of the requested size
* will be allocated on the device.
*
* @param new_size Size in bytes to allocate for the workspace.
*/
void realloc(size_t new_size) {
if (new_size > allocated) {
clear();
ACL_CHECK(aclrtMalloc(&ptr, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
allocated = new_size;
} }
ACL_CHECK(aclrtMalloc(&g_nz_workspace, new_size, ACL_MEM_MALLOC_HUGE_FIRST));
g_nz_workspace_allocated = new_size;
} }
}
} /**
* @brief Get the device buffer pointer.
*
* @return Pointer to the allocated buffer, or nullptr if not allocated.
*/
void* get() const { return ptr; }
};
/**
* @brief Global array of NZ workspaces, one per device.
*/
static ggml_cann_nz_workspace g_nz_workspaces[GGML_CANN_MAX_DEVICES];
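A hedged usage sketch (hypothetical helper; the real call site is weight_format_to_nz below, and the release happens at graph-compute time):

// names g_nz_workspaces/realloc/get/clear are the ones defined above
static void * nz_workspace_for(int device, size_t required) {
    g_nz_workspaces[device].realloc(required);  // grow-only: no-op if already large enough
    return g_nz_workspaces[device].get();
}
// g_nz_workspaces[device].clear() releases the buffer, e.g. at the start of graph compute.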
/** /**
* @brief Convert tensor weights to NZ format using Ascend CANN API. * @brief Convert tensor weights to NZ format using Ascend CANN API.
@ -1149,13 +1184,13 @@ namespace {
* improve performance on certain hardware. * improve performance on certain hardware.
* *
* @param tensor Pointer to the input ggml_tensor containing the weights. * @param tensor Pointer to the input ggml_tensor containing the weights.
* @param data Pointer to the raw data buffer for the tensor weights.
* @param offset Byte offset within the tensor data buffer where weights start. * @param offset Byte offset within the tensor data buffer where weights start.
* @param device device id.
* *
* @note The workspace buffer used in this function is managed globally and reused * @note The workspace buffer used in this function is managed globally and reused
* across calls. This reduces overhead from repeated memory allocation and deallocation. * across calls. This reduces overhead from repeated memory allocation and deallocation.
*/ */
static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) { static void weight_format_to_nz(ggml_tensor *tensor, size_t offset, int device) {
aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne, aclTensor* weightTransposed = ggml_cann_create_tensor(tensor, tensor->ne,
tensor->nb, 2, ACL_FORMAT_ND, offset); tensor->nb, 2, ACL_FORMAT_ND, offset);
uint64_t workspaceSize = 0; uint64_t workspaceSize = 0;
@ -1165,7 +1200,9 @@ static void weight_format_to_nz(ggml_tensor *tensor, size_t offset) {
ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed, ACL_CHECK(aclnnTransMatmulWeightGetWorkspaceSize(weightTransposed,
&workspaceSize, &executor)); &workspaceSize, &executor));
// Avoid frequent malloc/free of the workspace. // Avoid frequent malloc/free of the workspace.
relloc_nz_workspace(workspaceSize); g_nz_workspaces[device].realloc(workspaceSize);
void* g_nz_workspace = g_nz_workspaces[device].get();
ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr)); ACL_CHECK(aclnnTransMatmulWeight(g_nz_workspace, workspaceSize, executor, nullptr));
ACL_CHECK(aclDestroyTensor(weightTransposed)); ACL_CHECK(aclDestroyTensor(weightTransposed));
@ -1196,14 +1233,14 @@ static void ggml_backend_cann_buffer_set_tensor(
// Why aclrtSynchronizeDevice? // Why aclrtSynchronizeDevice?
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
if (!need_transform(tensor->type)) { if (!need_transform(tensor->type)) {
ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size, ACL_CHECK(aclrtMemcpy((char *)tensor->data + offset, size, data, size,
ACL_MEMCPY_HOST_TO_DEVICE)); ACL_MEMCPY_HOST_TO_DEVICE));
if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) { if (weight_to_nz && is_matmul_weight((const ggml_tensor*)tensor)) {
GGML_ASSERT(tensor->ne[2] == 1); GGML_ASSERT(tensor->ne[2] == 1);
GGML_ASSERT(tensor->ne[3] == 1); GGML_ASSERT(tensor->ne[3] == 1);
weight_format_to_nz(tensor, offset); weight_format_to_nz(tensor, offset, ctx->device);
} }
} else { } else {
void *transform_buffer = malloc(size); void *transform_buffer = malloc(size);
@ -1279,6 +1316,10 @@ static bool ggml_backend_cann_buffer_cpy_tensor(
ACL_MEMCPY_DEVICE_TO_DEVICE)); ACL_MEMCPY_DEVICE_TO_DEVICE));
return true; return true;
} else { } else {
#ifdef ASCEND_310P
// TODO: Support 310p P2P copy
return false;
#endif
// Different device but can access by peer. // Different device but can access by peer.
int32_t canAccessPeer = 0; int32_t canAccessPeer = 0;
ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device, ACL_CHECK(aclrtDeviceCanAccessPeer(&canAccessPeer, src_ctx->device,
@ -1439,7 +1480,7 @@ static size_t ggml_backend_cann_buffer_type_get_alloc_size(
int64_t ne0 = tensor->ne[0]; int64_t ne0 = tensor->ne[0];
// Only check env once. // Only check env once.
static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("")); static bool weight_to_nz = parse_bool(get_env("GGML_CANN_WEIGHT_NZ").value_or("on"));
// the last line must be bigger than 32, because every single op deals with at // the last line must be bigger than 32, because every single op deals with at
// least 32 bytes. // least 32 bytes.
@ -2000,6 +2041,8 @@ static bool ggml_backend_cann_cpy_tensor_async(
GGML_ASSERT(ggml_backend_is_cann(backend_src) || GGML_ASSERT(ggml_backend_is_cann(backend_src) ||
ggml_backend_is_cann(backend_dst)); ggml_backend_is_cann(backend_dst));
GGML_ASSERT(!is_matmul_weight((const ggml_tensor*)src));
if (!ggml_backend_buffer_is_cann(src->buffer) || if (!ggml_backend_buffer_is_cann(src->buffer) ||
!ggml_backend_buffer_is_cann(dst->buffer)) { !ggml_backend_buffer_is_cann(dst->buffer)) {
return false; return false;
@ -2020,6 +2063,10 @@ static bool ggml_backend_cann_cpy_tensor_async(
return true; return true;
} }
if (backend_src != backend_dst) { if (backend_src != backend_dst) {
#ifdef ASCEND_310P
// TODO: Support 310p P2P copy
return false;
#endif
ggml_backend_cann_buffer_context* buf_ctx_src = ggml_backend_cann_buffer_context* buf_ctx_src =
(ggml_backend_cann_buffer_context*)buf_src->context; (ggml_backend_cann_buffer_context*)buf_src->context;
ggml_backend_cann_buffer_context* buf_ctx_dst = ggml_backend_cann_buffer_context* buf_ctx_dst =
@ -2036,7 +2083,6 @@ static bool ggml_backend_cann_cpy_tensor_async(
} }
// need open both directions for memcpyasync between devices. // need open both directions for memcpyasync between devices.
ggml_cann_set_device(cann_ctx_dst->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0)); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_src->device, 0));
ggml_cann_set_device(cann_ctx_src->device); ggml_cann_set_device(cann_ctx_src->device);
ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0)); ACL_CHECK(aclrtDeviceEnablePeerAccess(cann_ctx_dst->device, 0));
@ -2047,8 +2093,15 @@ static bool ggml_backend_cann_cpy_tensor_async(
ACL_MEMCPY_DEVICE_TO_DEVICE, ACL_MEMCPY_DEVICE_TO_DEVICE,
cann_ctx_src->stream())); cann_ctx_src->stream()));
// TODO: workaround because Event didn't work here. // record event on src stream after the copy
aclrtSynchronizeStream(cann_ctx_src->stream()); if (!cann_ctx_src->copy_event) {
ACL_CHECK(aclrtCreateEventWithFlag(&cann_ctx_src->copy_event, ACL_EVENT_SYNC));
}
ACL_CHECK(aclrtRecordEvent(cann_ctx_src->copy_event, cann_ctx_src->stream()));
// wait on dst stream for the copy to complete
ggml_cann_set_device(cann_ctx_dst->device);
ACL_CHECK(aclrtStreamWaitEvent(cann_ctx_dst->stream(), cann_ctx_src->copy_event));
} else { } else {
// src and dst are on the same backend // src and dst are on the same backend
ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size, ACL_CHECK(aclrtMemcpyAsync(dst->data, copy_size, src->data, copy_size,
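A condensed sketch of the record/wait pattern introduced above (assuming the usual CANN runtime header; error handling and ACL_CHECK elided):

#include <acl/acl.h>

// copy on the src stream, record an event, and make the dst stream wait on it,
// so no host-side stream synchronization is needed
static void cross_device_copy_async(void * dst, const void * src, size_t size,
                                    aclrtStream src_stream, aclrtStream dst_stream,
                                    aclrtEvent * copy_event) {  // lazily created, reused
    if (*copy_event == nullptr) {
        aclrtCreateEventWithFlag(copy_event, ACL_EVENT_SYNC);
    }
    aclrtMemcpyAsync(dst, size, src, size, ACL_MEMCPY_DEVICE_TO_DEVICE, src_stream);
    aclrtRecordEvent(*copy_event, src_stream);      // marks the copy on the src stream
    aclrtStreamWaitEvent(dst_stream, *copy_event);  // dst stream blocks until the copy lands
}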
@ -2246,11 +2299,16 @@ static enum ggml_status ggml_backend_cann_graph_compute(
ggml_backend_cann_context* cann_ctx = ggml_backend_cann_context* cann_ctx =
(ggml_backend_cann_context*)backend->context; (ggml_backend_cann_context*)backend->context;
ggml_cann_set_device(cann_ctx->device); ggml_cann_set_device(cann_ctx->device);
release_nz_workspace(); g_nz_workspaces[cann_ctx->device].clear();
#ifdef USE_ACL_GRAPH #ifdef USE_ACL_GRAPH
bool use_cann_graph = true; bool use_cann_graph = true;
bool cann_graph_update_required = false; bool cann_graph_update_required = false;
if (!cann_ctx->acl_graph_mode) {
use_cann_graph = false;
}
if (use_cann_graph) { if (use_cann_graph) {
if (cann_ctx->cann_graph == nullptr) { if (cann_ctx->cann_graph == nullptr) {
cann_ctx->cann_graph.reset(new ggml_cann_graph()); cann_ctx->cann_graph.reset(new ggml_cann_graph());
@ -2400,16 +2458,10 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
} }
case GGML_OP_ROPE: { case GGML_OP_ROPE: {
// TODO: with ops-test v == 1 // TODO: with ops-test v == 1
float ext_factor = 0.0f;
memcpy(&ext_factor, (const float *) op->op_params + 7, sizeof(float));
// TODO: n_dims <= ne0 // TODO: n_dims <= ne0
if (op->src[0]->ne[0] != op->op_params[1]) { if (op->src[0]->ne[0] != op->op_params[1]) {
return false; return false;
} }
// TODO: ext_factor != 0
if (ext_factor != 0) {
return false;
}
const int mode = ((const int32_t *) op->op_params)[2]; const int mode = ((const int32_t *) op->op_params)[2];
if (mode & GGML_ROPE_TYPE_MROPE) { if (mode & GGML_ROPE_TYPE_MROPE) {
@ -2418,10 +2470,11 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
if (mode & GGML_ROPE_TYPE_VISION) { if (mode & GGML_ROPE_TYPE_VISION) {
return false; return false;
} }
#ifdef ASCEND_310P
if(!ggml_is_contiguous(op->src[0])){ if(!ggml_is_contiguous(op->src[0])){
return false; return false;
} }
#endif
return true; return true;
} }
case GGML_OP_UPSCALE: { case GGML_OP_UPSCALE: {
@ -2483,12 +2536,14 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
case GGML_OP_ARGMAX: case GGML_OP_ARGMAX:
case GGML_OP_COS: case GGML_OP_COS:
case GGML_OP_SIN: case GGML_OP_SIN:
case GGML_OP_CONV_TRANSPOSE_1D:
case GGML_OP_LOG: case GGML_OP_LOG:
case GGML_OP_MEAN: case GGML_OP_MEAN:
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_COUNT_EQUAL: case GGML_OP_COUNT_EQUAL:
return true; return true;
case GGML_OP_CONV_TRANSPOSE_1D:
// TODO: ((weightL - 1) * dilationW - padLeft)=1336 should not be larger than 255.
return (op->src[0]->ne[0] - 1) <= 255;
case GGML_OP_SCALE: case GGML_OP_SCALE:
float bias; float bias;
memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float)); memcpy(&bias, (const float *)(op->op_params) + 1, sizeof(float));
@ -2522,13 +2577,6 @@ static bool ggml_backend_cann_supports_op(ggml_backend_dev_t dev,
// different head sizes of K and V are not supported yet // different head sizes of K and V are not supported yet
return false; return false;
} }
if (op->src[0]->ne[0] == 192) {
return false;
}
if (op->src[0]->ne[0] == 576) {
// DeepSeek MLA
return false;
}
if (op->src[0]->ne[0] % 16 != 0) { if (op->src[0]->ne[0] % 16 != 0) {
// TODO: padding to support // TODO: padding to support
return false; return false;

View File

@ -433,15 +433,22 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
ggml-cpu/arch/riscv/quants.c ggml-cpu/arch/riscv/quants.c
ggml-cpu/arch/riscv/repack.cpp ggml-cpu/arch/riscv/repack.cpp
) )
if (GGML_RVV) set(MARCH_STR "rv64gc")
if (GGML_XTHEADVECTOR) if (GGML_RV_ZFH)
list(APPEND ARCH_FLAGS -march=rv64gc_zfhmin_xtheadvector -mabi=lp64d) string(APPEND MARCH_STR "_zfh")
elseif (GGML_RV_ZFH) endif()
list(APPEND ARCH_FLAGS -march=rv64gcv_zfhmin -mabi=lp64d) if (GGML_XTHEADVECTOR)
else() string(APPEND MARCH_STR "_xtheadvector")
list(APPEND ARCH_FLAGS -march=rv64gcv -mabi=lp64d) elseif (GGML_RVV)
string(APPEND MARCH_STR "_v")
if (GGML_RV_ZVFH)
string(APPEND MARCH_STR "_zvfh")
endif() endif()
endif() endif()
if (GGML_RV_ZICBOP)
string(APPEND MARCH_STR "_zicbop")
endif()
list(APPEND ARCH_FLAGS "-march=${MARCH_STR}" -mabi=lp64d)
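# For example, GGML_RV_ZFH + GGML_RVV + GGML_RV_ZVFH produces -march=rv64gc_zfh_v_zvfh
# (with _zicbop appended when GGML_RV_ZICBOP is also set).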
elseif (GGML_SYSTEM_ARCH STREQUAL "s390x") elseif (GGML_SYSTEM_ARCH STREQUAL "s390x")
message(STATUS "s390x detected") message(STATUS "s390x detected")
list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c) list(APPEND GGML_CPU_SOURCES ggml-cpu/arch/s390/quants.c)
@ -450,7 +457,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# TODO: Separation to determine activation of VX/VXE/VXE2 # TODO: Separation to determine activation of VX/VXE/VXE2
if (${S390X_M} MATCHES "8561|8562") if (${S390X_M} MATCHES "8561|8562")
set(GGML_NNPA OFF)
message(STATUS "z15 target") message(STATUS "z15 target")
list(APPEND ARCH_FLAGS -march=z15) list(APPEND ARCH_FLAGS -march=z15)
elseif (${S390X_M} MATCHES "3931") elseif (${S390X_M} MATCHES "3931")
@ -472,11 +478,6 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND ARCH_FLAGS -mvx -mzvector) list(APPEND ARCH_FLAGS -mvx -mzvector)
list(APPEND ARCH_DEFINITIONS GGML_VXE) list(APPEND ARCH_DEFINITIONS GGML_VXE)
endif() endif()
if (GGML_NNPA)
message(STATUS "NNPA enabled")
list(APPEND ARCH_DEFINITIONS GGML_NNPA)
endif()
elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm") elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "wasm")
message(STATUS "Wasm detected") message(STATUS "Wasm detected")
list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c) list (APPEND GGML_CPU_SOURCES ggml-cpu/arch/wasm/quants.c)
@ -497,9 +498,9 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
# Fetch KleidiAI sources: # Fetch KleidiAI sources:
include(FetchContent) include(FetchContent)
set(KLEIDIAI_COMMIT_TAG "v1.11.0") set(KLEIDIAI_COMMIT_TAG "v1.13.0")
set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz") set(KLEIDIAI_DOWNLOAD_URL "https://github.com/ARM-software/kleidiai/archive/refs/tags/${KLEIDIAI_COMMIT_TAG}.tar.gz")
set(KLEIDIAI_ARCHIVE_MD5 "3fe9e5ab964c375c53839296eb71eaa2") set(KLEIDIAI_ARCHIVE_MD5 "d82a8de939d9814621a5ba23907bdac1")
if (POLICY CMP0135) if (POLICY CMP0135)
cmake_policy(SET CMP0135 NEW) cmake_policy(SET CMP0135 NEW)
@ -555,6 +556,7 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
list(APPEND GGML_KLEIDIAI_SOURCES list(APPEND GGML_KLEIDIAI_SOURCES
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qsi8d32p_f32_neon.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c) ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0.c)
@ -576,7 +578,8 @@ function(ggml_add_cpu_backend_variant_impl tag_name)
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_pack_bf16p2vlx2_f32_sme.c
${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c) ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.c
${KLEIDIAI_SRC}/kai/kai_common_sme_asm.S)
set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2") set(PRIVATE_ARCH_FLAGS "-fno-tree-vectorize;${PRIVATE_ARCH_FLAGS}+sve+sve2")
endif() endif()

View File

@ -1270,29 +1270,40 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d); const float d = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].d);
const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin); const float dmin = y[i].d * GGML_CPU_FP16_TO_FP32(x[i].dmin);
int tmp, tmp2, sumi; float ftmp, ft2;
const uint8_t * restrict q40;
const uint8_t * restrict q41;
const uint8_t * restrict q42;
const uint8_t * restrict q43;
const int8_t * restrict q80;
const int8_t * restrict q81;
const int8_t * restrict q82;
const int8_t * restrict q83;
int s0, s1, s2, s3;
__asm__ __volatile__( __asm__ __volatile__(
"vsetivli zero, 12, e8, m1\n\t" "li %[s1], 8\n\t"
"vle8.v v1, (%[s6b])\n\t" // {aux[0], aux[1], aux[2]} "vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vsetivli zero, 4, e32, m1\n\t" "vle32.v v1, (%[s6b])\n\t"
"vslide1down.vx v1, v1, zero\n\t"
"vmv.v.x v16, zero\n\t"
"vslidedown.vi v2, v1, 2\n\t" "vslidedown.vi v2, v1, 2\n\t"
"vmv1r.v v3, v2\n\t" "vmv1r.v v3, v2\n\t"
"vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]} "vslideup.vi v2, v3, 1\n\t" // {aux[2], aux[2]}
"vsetivli zero, 2, e32, m1\n\t" "vsetivli zero, 2, e32, m1, ta, ma\n\t"
"vmv.v.i v4, 4\n\t" "vmv.v.i v4, 4\n\t"
"vand.vx v8, v1, %[kmask1]\n\t" "vand.vx v8, v1, %[kmask1]\n\t"
"vslide1up.vx v5, v4, zero\n\t" // {0, 4} "vslide1up.vx v5, v4, zero\n\t" // {0, 4}
"vsrl.vi v6, v1, 6\n\t" "vsrl.vi v6, v1, 6\n\t"
"vsrl.vv v7, v2, v5\n\t" "vsrl.vv v7, v2, v5\n\t"
"vsse32.v v8, (%[utmp]), %[s1]\n\t"
"vand.vx v0, v6, %[kmask3]\n\t" "vand.vx v0, v6, %[kmask3]\n\t"
"vand.vx v2, v7, %[kmask2]\n\t" "vand.vx v2, v7, %[kmask2]\n\t"
"vsll.vi v6, v0, 4\n\t" "vsll.vi v6, v0, 4\n\t"
"li %[t2], 8\n\t" "addi %[s0], %[utmp], 4\n\t"
"addi %[t1], %[utmp], 4\n\t"
"vor.vv v1, v6, v2\n\t" "vor.vv v1, v6, v2\n\t"
"vsse32.v v8, (%[utmp]), %[t2]\n\t" "vsse32.v v1, (%[s0]), %[s1]\n\t"
"vsse32.v v1, (%[t1]), %[t2]\n\t" "vsetivli zero, 8, e16, m1, ta, ma\n\t"
"vsetivli zero, 8, e16, m1\n\t"
"vle32.v v2, (%[bsums])\n\t" "vle32.v v2, (%[bsums])\n\t"
"vnsrl.wi v0, v2, 0\n\t" "vnsrl.wi v0, v2, 0\n\t"
"vnsrl.wi v1, v2, 16\n\t" "vnsrl.wi v1, v2, 16\n\t"
@ -1300,13 +1311,131 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vle8.v v3, (%[mins])\n\t" "vle8.v v3, (%[mins])\n\t"
"vzext.vf2 v4, v3\n\t" "vzext.vf2 v4, v3\n\t"
"vwmul.vv v6, v4, v2\n\t" "vwmul.vv v6, v4, v2\n\t"
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vredsum.vs v0, v6, v16\n\t"
"vredsum.vs v0, v7, v0\n\t"
"vfcvt.f.x.v v0, v0\n\t"
"vfmv.f.s %[ftmp], v0\n\t"
"vsetivli zero, 16, e8, m1, ta, ma\n\t"
"vle8.v v0, (%[xs])\n\t"
"fnmsub.s %[sumf], %[dmin], %[ftmp], %[sumf]\n\t"
"addi %[q40], %[xs], 64\n\t"
"addi %[q41], %[xs], 16\n\t"
"addi %[q42], %[xs], 32\n\t"
"addi %[q43], %[xs], 48\n\t"
"addi %[q80], %[ys], 64\n\t"
"vle8.v v1, (%[q41])\n\t"
"vle8.v v2, (%[q42])\n\t"
"addi %[q81], %[ys], 16\n\t"
"addi %[q41], %[q41], 64\n\t"
"addi %[q82], %[ys], 32\n\t"
"vle8.v v3, (%[q43])\n\t"
"vle8.v v8, (%[ys])\n\t"
"addi %[q42], %[q42], 64\n\t"
"addi %[q83], %[ys], 48\n\t"
"addi %[q43], %[q43], 64\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vle8.v v9, (%[q81])\n\t"
"vle8.v v10, (%[q82])\n\t"
"vand.vi v0, v0, 0xF\n\t"
"addi %[q81], %[q81], 64\n\t"
"vsrl.vi v5, v1, 4\n\t"
"addi %[q82], %[q82], 64\n\t"
"vle8.v v11, (%[q83])\n\t"
"vle8.v v12, (%[q80])\n\t"
"vand.vi v1, v1, 0xF\n\t"
"addi %[q83], %[q83], 64\n\t"
"vsrl.vi v6, v2, 4\n\t"
"addi %[q80], %[q80], 64\n\t"
"vle8.v v13, (%[q81])\n\t"
"vle8.v v14, (%[q82])\n\t"
"vand.vi v2, v2, 0xF\n\t"
"addi %[q81], %[q81], 64\n\t"
"vsrl.vi v7, v3, 4\n\t"
"addi %[q82], %[q82], 64\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vle8.v v15, (%[q83])\n\t"
"vle8.v v0, (%[q40])\n\t"
"vand.vi v3, v3, 0xF\n\t"
"addi %[q83], %[q83], 64\n\t"
"vwmul.vv v24, v2, v12\n\t"
"vwmul.vv v20, v4, v10\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmacc.vv v16, v1, v9\n\t"
"vle8.v v1, (%[q41])\n\t"
"vle8.v v2, (%[q42])\n\t"
"vwmacc.vv v24, v3, v13\n\t"
"vwmacc.vv v20, v5, v11\n\t"
"vwmacc.vv v28, v7, v15\n\t"
"addi %[q40], %[q80], 64\n\t"
"addi %[q41], %[q81], 64\n\t"
"vle8.v v3, (%[q43])\n\t"
"vle8.v v8, (%[q80])\n\t"
"addi %[q42], %[q82], 64\n\t"
"addi %[q43], %[q83], 64\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vle8.v v9, (%[q81])\n\t"
"vle8.v v10, (%[q82])\n\t"
"vand.vi v0, v0, 0xF\n\t"
"vsrl.vi v5, v1, 4\n\t"
"vsrl.vi v7, v3, 4\n\t"
"vand.vi v3, v3, 0xF\n\t"
"vle8.v v11, (%[q83])\n\t"
"vle8.v v12, (%[q40])\n\t"
"vand.vi v1, v1, 0xF\n\t"
"vsrl.vi v6, v2, 4\n\t"
"vand.vi v2, v2, 0xF\n\t"
"vwmul.vv v18, v0, v8\n\t"
"vle8.v v13, (%[q41])\n\t"
"vle8.v v14, (%[q42])\n\t"
"vwmul.vv v26, v2, v12\n\t"
"vwmul.vv v22, v4, v10\n\t"
"vwmul.vv v30, v6, v14\n\t"
"vwmacc.vv v18, v1, v9\n\t"
"vle8.v v15, (%[q43])\n\t"
"vwmacc.vv v26, v3, v13\n\t"
"vwmacc.vv v22, v5, v11\n\t"
"vwmacc.vv v30, v7, v15\n\t"
"vmv.v.x v0, zero\n\t" "vmv.v.x v0, zero\n\t"
"vsetivli zero, 8, e32, m2\n\t" "vsetivli zero, 16, e16, m2, ta, ma\n\t"
"vredsum.vs v0, v6, v0\n\t" "vwredsum.vs v4, v16, v0\n\t"
"vmv.x.s %[sumi], v0" "lbu %[s0], 0(%[scale])\n\t"
: [t1] "=&r" (tmp), [t2] "=&r" (tmp2), [sumi] "=&r" (sumi) "vwredsum.vs v5, v20, v0\n\t"
: [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp) "lbu %[s1], 1(%[scale])\n\t"
, [s6b] "r" (x[i].scales), [kmask1] "r" (kmask1) "vwredsum.vs v6, v24, v0\n\t"
"lbu %[s2], 2(%[scale])\n\t"
"vwredsum.vs v7, v28, v0\n\t"
"lbu %[s3], 3(%[scale])\n\t"
"vwredsum.vs v8, v18, v0\n\t"
"lbu %[q40], 4(%[scale])\n\t"
"vwredsum.vs v9, v22, v0\n\t"
"lbu %[q41], 5(%[scale])\n\t"
"vwredsum.vs v10, v26, v0\n\t"
"lbu %[q42], 6(%[scale])\n\t"
"vwredsum.vs v11, v30, v0\n\t"
"lbu %[q43], 7(%[scale])\n\t"
"vsetivli zero, 4, e32, m1, ta, ma\n\t"
"vmul.vx v0, v4, %[s0]\n\t"
"vmul.vx v1, v8, %[q40]\n\t"
"vmacc.vx v0, %[s1], v5\n\t"
"vmacc.vx v1, %[q41], v9\n\t"
"vmacc.vx v0, %[s2], v6\n\t"
"vmacc.vx v1, %[q42], v10\n\t"
"vmacc.vx v0, %[s3], v7\n\t"
"vmacc.vx v1, %[q43], v11\n\t"
"vfcvt.f.x.v v0, v0\n\t"
"vfcvt.f.x.v v1, v1\n\t"
"vfmv.f.s %[ft2], v0\n\t"
"vfmv.f.s %[ftmp], v1\n\t"
"fadd.s %[ft2], %[ft2], %[ftmp]\n\t"
"fmadd.s %[sumf], %[d], %[ft2], %[sumf]"
: [ftmp] "=&f" (ftmp), [sumf] "+&f" (sumf), [ft2] "=&f" (ft2)
, [s0] "=&r" (s0), [s1] "=&r" (s1), [s2] "=&r" (s2), [s3] "=&r" (s3)
, [q40] "=&r" (q40), [q41] "=&r" (q41), [q42] "=&r" (q42), [q43] "=&r" (q43)
, [q80] "=&r" (q80), [q81] "=&r" (q81), [q82] "=&r" (q82), [q83] "=&r" (q83)
: [d] "f" (d), [ys] "r" (y[i].qs), [xs] "r" (x[i].qs), [scale] "r" (scales)
, [bsums] "r" (y[i].bsums), [mins] "r" (mins), [utmp] "r" (utmp)
, [s6b] "r" (&x[i]), [kmask1] "r" (kmask1), [dmin] "f" (dmin)
, [kmask2] "r" (kmask2), [kmask3] "r" (kmask3) , [kmask2] "r" (kmask2), [kmask3] "r" (kmask3)
: "memory" : "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
@ -1314,59 +1443,6 @@ void ggml_vec_dot_q4_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
); );
sumf -= dmin * sumi;
const uint8_t * restrict q4 = x[i].qs;
const int8_t * restrict q8 = y[i].qs;
sumi = 0;
const uint8_t * scale = scales;
for (int j = 0; j < QK_K/128; ++j) {
int vl128 = 128, vl64 = 64, vl32 = 32;
__asm__ __volatile__(
"vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v8, (%[q8])\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
"vle8.v v0, (%[q4])\n\t"
"vsrl.vi v4, v0, 4\n\t"
"vand.vi v0, v0, 0xF\n\t"
"vsetvli zero, %[vl32], e8, m2\n\t"
"vwmul.vv v28, v6, v14\n\t"
"vwmul.vv v20, v4, v10\n\t"
"vwmul.vv v24, v2, v12\n\t"
"vwmul.vv v16, v0, v8\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vle8.v v2, (%[scale])\n\t"
"vmv.v.x v0, zero\n\t"
"vzext.vf4 v1, v2\n\t"
"vsetvli zero, %[vl32], e16, m4\n\t"
"vwredsum.vs v6, v24, v0\n\t"
"vwredsum.vs v7, v28, v0\n\t"
"vwredsum.vs v4, v16, v0\n\t"
"vwredsum.vs v5, v20, v0\n\t"
"vsetivli zero, 4, e32, m1\n\t"
"vslideup.vi v6, v7, 1\n\t"
"vslideup.vi v4, v5, 1\n\t"
"vslideup.vi v4, v6, 2\n\t"
"vmul.vv v8, v4, v1\n\t"
"vredsum.vs v0, v8, v0\n\t"
"vmv.x.s %[tmp], v0\n\t"
"add %[sumi], %[sumi], %[tmp]"
: [tmp] "=&r" (tmp), [sumi] "+&r" (sumi)
: [vl128] "r" (vl128), [vl64] "r" (vl64), [vl32] "r" (vl32)
, [q4] "r" (q4), [q8] "r" (q8), [scale] "r" (scale)
: "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
);
q4 += 64; q8 += 128; scale += 4;
}
sumf += d * sumi;
} }
break; break;
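For readers of the vector code above, a scalar sketch of what one Q4_K super-block contributes (assuming the 8 scales and 8 mins have already been unpacked from the 12-byte scales field, which is what the kmask/utmp shuffle at the top of the asm does):

#include <stdint.h>

// q4 packs two 4-bit quants per byte; bsums are the per-16 sums of the q8 block;
// d and dmin are the combined (x.d * y.d) scales, as computed before the asm
static float q4k_block_dot(float d, float dmin,
                           const uint8_t * q4,     // 128 bytes -> 256 quants
                           const int8_t  * q8,     // 256 int8 activations
                           const int16_t * bsums,  // 16 partial sums of q8
                           const uint8_t * sc, const uint8_t * mins) {
    int32_t sumi = 0;
    for (int j = 0; j < 8; ++j) {               // 8 sub-blocks of 32 values
        int32_t s = 0;
        for (int l = 0; l < 32; ++l) {
            const int q = (j & 1) ? (q4[(j/2)*32 + l] >> 4)
                                  : (q4[(j/2)*32 + l] & 0xF);
            s += q * q8[j*32 + l];
        }
        sumi += sc[j] * s;
    }
    int32_t sumi_mins = 0;
    for (int j = 0; j < 8; ++j) {
        sumi_mins += mins[j] * (bsums[2*j] + bsums[2*j + 1]);
    }
    return d * sumi - dmin * sumi_mins;         // accumulated into sumf per super-block
}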
default: default:
@ -1693,6 +1769,8 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
case 128: case 128:
for (int i = 0; i < nb; ++i) { for (int i = 0; i < nb; ++i) {
__builtin_prefetch(&x[i + 1].d, 0, 1);
const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d; const float d = GGML_CPU_FP16_TO_FP32(x[i].d) * y[i].d;
const uint8_t * restrict q6 = x[i].ql; const uint8_t * restrict q6 = x[i].ql;
@ -1701,23 +1779,59 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
const int8_t * restrict scale = x[i].scales; const int8_t * restrict scale = x[i].scales;
int sum_t = 0; int q6h;
int t0; float ftmp;
for (int j = 0; j < QK_K/128; ++j) { for (int j = 0; j < QK_K/128; ++j) {
__asm__ __volatile__( __asm__ __volatile__(
"addi %[q6h], %[q6], 32\n\t"
"ld t0, 0(%[scale])\n\t"
"addi %[scale], %[scale], 8\n\t"
"slli t6, t0, 1 * 8\n\t"
"lb zero, 0(%[q6])\n\t"
"slli t5, t0, 2 * 8\n\t"
"slli t4, t0, 3 * 8\n\t"
"lb zero, 0(%[q6h])\n\t"
"slli t3, t0, 4 * 8\n\t"
"slli t2, t0, 5 * 8\n\t"
"lb zero, 0(%[qh])\n\t"
"lb zero, 31(%[q6h])\n\t"
"slli t1, t0, 6 * 8\n\t"
"srai a7, t0, 56\n\t"
"vsetvli zero, %[vl32], e8, m2\n\t" "vsetvli zero, %[vl32], e8, m2\n\t"
"vle8.v v8, (%[q6])\n\t"
"srai t6, t6, 56\n\t"
"srai t5, t5, 56\n\t"
"srai t4, t4, 56\n\t"
"srai t3, t3, 56\n\t"
"vle8.v v10, (%[q6h])\n\t"
"addi %[q6], %[q6], 64\n\t"
"slli t0, t0, 7 * 8\n\t"
"srai t2, t2, 56\n\t"
"srai t1, t1, 56\n\t"
"srai t0, t0, 56\n\t"
"vle8.v v4, (%[qh])\n\t" "vle8.v v4, (%[qh])\n\t"
"vsrl.vi v12, v8, 4\n\t"
"vsrl.vi v14, v10, 4\n\t"
"lb zero, 0(%[q8])\n\t"
"vand.vi v8, v8, 0xF\n\t"
"vand.vi v10, v10, 0xF\n\t"
"lb zero, 32(%[q8])\n\t"
"vsll.vi v0, v4, 4\n\t" "vsll.vi v0, v4, 4\n\t"
"vsll.vi v2, v4, 2\n\t" "vsll.vi v2, v4, 2\n\t"
"lb zero, 64(%[q8])\n\t"
"vsrl.vi v6, v4, 2\n\t" "vsrl.vi v6, v4, 2\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t"
"vle8.v v8, (%[q6])\n\t"
"vsrl.vi v12, v8, 4\n\t"
"vand.vi v8, v8, 0xF\n\t"
"vsetvli zero, %[vl128], e8, m8\n\t"
"vand.vx v0, v0, %[mask]\n\t" "vand.vx v0, v0, %[mask]\n\t"
"lb zero, 96(%[q8])\n\t"
"vand.vx v2, v2, %[mask]\n\t"
"vand.vx v4, v4, %[mask]\n\t"
"vand.vx v6, v6, %[mask]\n\t"
"vor.vv v8, v8, v0\n\t" "vor.vv v8, v8, v0\n\t"
"lb zero, 127(%[q8])\n\t"
"vor.vv v10, v10, v2\n\t"
"vor.vv v12, v12, v4\n\t"
"vor.vv v14, v14, v6\n\t"
"vsetvli zero, %[vl128], e8, m8\n\t"
"vle8.v v0, (%[q8])\n\t" "vle8.v v0, (%[q8])\n\t"
"vsub.vx v8, v8, %[vl32]\n\t" "vsub.vx v8, v8, %[vl32]\n\t"
"vsetvli zero, %[vl64], e8, m4\n\t" "vsetvli zero, %[vl64], e8, m4\n\t"
@ -1734,34 +1848,34 @@ void ggml_vec_dot_q6_K_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const voi
"vwredsum.vs v13, v28, v0\n\t" "vwredsum.vs v13, v28, v0\n\t"
"vwredsum.vs v14, v30, v0\n\t" "vwredsum.vs v14, v30, v0\n\t"
"vsetivli zero, 4, e32, m1\n\t" "vsetivli zero, 4, e32, m1\n\t"
"vslideup.vi v10, v9, 1\n\t" "vmul.vx v0, v10, t0\n\t"
"vslideup.vi v8, v7, 1\n\t" "vmul.vx v1, v9, t1\n\t"
"vslideup.vi v11, v12, 1\n\t" "vmacc.vx v0, t2, v8\n\t"
"vslideup.vi v13, v14, 1\n\t" "vmacc.vx v1, t3, v7\n\t"
"vslideup.vi v10, v8, 2\n\t" "vmacc.vx v0, t4, v11\n\t"
"vslideup.vi v11, v13, 2\n\t" "vmacc.vx v1, t5, v12\n\t"
"vsetivli zero, 8, e32, m2\n\t" "vmacc.vx v0, t6, v13\n\t"
"vle8.v v2, (%[scale])\n\t" "vmacc.vx v1, a7, v14\n\t"
"vsext.vf4 v4, v2\n\t" "vadd.vv v0, v0, v1\n\t"
"vmul.vv v2, v4, v10\n\t" "vfcvt.f.x.v v0, v0\n\t"
"vredsum.vs v0, v2, v0\n\t" "vfmv.f.s %[ftmp], v0\n\t"
"vmv.x.s %[t0], v0\n\t" "fmadd.s %[sumf], %[d], %[ftmp], %[sumf]"
"add %[sumi], %[sumi], %[t0]" : [q6] "+&r" (q6), [q6h] "=&r" (q6h)
: [sumi] "+&r" (sum_t), [t0] "=&r" (t0) , [scale] "+&r" (scale)
: [qh] "r" (qh), [q6] "r" (q6), [q8] "r" (q8), [scale] "r" (scale) , [sumf] "+&f" (sumf), [ftmp] "=&f" (ftmp)
: [qh] "r" (qh), [q8] "r" (q8)
, [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128) , [vl32] "r" (32), [vl64] "r" (64), [vl128] "r" (128)
, [mask] "r" (0x30) , [mask] "r" (0x30), [d] "f" (d)
: "memory" : "memory"
, "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7" , "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7"
, "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15" , "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15"
, "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23" , "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23"
, "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31" , "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31"
, "t0", "t1", "t2", "t3", "t4", "t5", "t6", "a7"
, "a6", "a5", "a4", "a3"
);
q6 += 64; qh += 32; q8 += 128; scale += 8; qh += 32; q8 += 128;
}
sumf += d * sum_t;
}
break;
default:
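For reference, the scale handling in the rewritten block above: the eight signed 8-bit scales are fetched with a single 64-bit load (ld t0) and each byte is recovered with a slli/srai pair. A minimal C sketch of the same extraction (the extract_scale helper is illustrative only, not part of the patch):

    #include <stdint.h>

    static inline int8_t extract_scale(uint64_t scales64, int k) {
        // same value the asm produces with "slli (56 - 8*k)" followed by "srai 56":
        // byte k of the packed scales, sign-extended
        return (int8_t) (scales64 >> (8 * k));
    }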

View File

@ -68,12 +68,6 @@ struct ggml_compute_params {
#endif // __VXE2__
#endif // __s390x__ && __VEC__
#if defined(__s390x__) && defined(GGML_NNPA)
#ifndef __NNPA__
#define __NNPA__
#endif // __NNPA__
#endif // __s390x__ && GGML_NNPA
#if defined(__ARM_FEATURE_SVE)
#include <sys/prctl.h>
#endif

View File

@ -1876,6 +1876,10 @@ static void ggml_compute_forward(struct ggml_compute_params * params, struct ggm
{
ggml_compute_forward_im2col_back_f32(params, tensor);
} break;
case GGML_OP_IM2COL_3D:
{
ggml_compute_forward_im2col_3d(params, tensor);
} break;
case GGML_OP_CONV_2D:
{
ggml_compute_forward_conv_2d(params, tensor);
@ -2255,6 +2259,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) {
} break;
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_BACK:
case GGML_OP_IM2COL_3D:
case GGML_OP_CONV_2D:
case GGML_OP_CONV_3D:
case GGML_OP_CONV_2D_DW:
@ -3206,20 +3211,12 @@ void ggml_cpu_fp32_to_fp16(const float * x, ggml_fp16_t * y, int64_t n) {
__m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
_mm_storel_epi64((__m128i *)(y + i), y_vec);
}
#elif defined(__NNPA__) #elif defined(__riscv_zvfh)
for (; i + 7 < n; i += 8) { for (int vl; i < n; i += vl) {
float32x4_t v_xh = vec_xl(0, (const float *)(x + i + 0)); vl = __riscv_vsetvl_e32m2(n - i);
float32x4_t v_xl = vec_xl(0, (const float *)(x + i + 4)); vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
uint16x8_t v_yd = vec_round_from_fp32(v_xh, v_xl, 0); vfloat16m1_t vy = __riscv_vfncvt_f_f_w_f16m1(vx, vl);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0); __riscv_vse16_v_f16m1((_Float16 *)&y[i], vy, vl);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
}
for (; i + 3 < n; i += 4) {
float32x4_t v_x = vec_xl(0, (const float *)(x + i));
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_yd = vec_round_from_fp32(v_x, v_zero, 0);
uint16x8_t v_y = vec_convert_to_fp16(v_yd, 0);
vec_xst(v_y, 0, (ggml_fp16_t *)(y + i));
} }
#endif
for (; i < n; ++i) {
@ -3247,21 +3244,6 @@ void ggml_cpu_fp16_to_fp32(const ggml_fp16_t * x, float * y, int64_t n) {
__m128 y_vec = _mm_cvtph_ps(x_vec);
_mm_storeu_ps(y + i, y_vec);
}
#elif defined(__NNPA__)
for (; i + 7 < n; i += 8) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
float32x4_t v_yl = vec_extend_to_fp32_lo(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i + 0));
vec_xst(v_yl, 0, (float *)(y + i + 4));
}
for (; i + 3 < n; i += 4) {
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)(x + i));
uint16x8_t v_yd = vec_convert_from_fp16(v_x, 0);
float32x4_t v_yh = vec_extend_to_fp32_hi(v_yd, 0);
vec_xst(v_yh, 0, (float *)(y + i));
}
#endif
for (; i < n; ++i) {
@ -3465,14 +3447,6 @@ int ggml_cpu_has_vxe(void) {
#endif
}
int ggml_cpu_has_nnpa(void) {
#if defined(GGML_NNPA)
return 1;
#else
return 0;
#endif
}
int ggml_cpu_has_neon(void) {
#if defined(__ARM_ARCH) && defined(__ARM_NEON)
return 1;

View File

@ -348,8 +348,10 @@ static void ggml_backend_cpu_device_get_memory(ggml_backend_dev_t dev, size_t *
long pages = sysconf(_SC_PHYS_PAGES);
long page_size = sysconf(_SC_PAGE_SIZE);
*total = pages * page_size;
// "free" system memory is ill-defined, for practical purposes assume that all of it is free:
*free = *total;
#endif #endif // _WIN32
GGML_UNUSED(dev);
}
@ -576,9 +578,6 @@ static ggml_backend_feature * ggml_backend_cpu_get_features(ggml_backend_reg_t r
if (ggml_cpu_has_vxe()) {
features.push_back({ "VXE", "1" });
}
if (ggml_cpu_has_nnpa()) {
features.push_back({ "NNPA", "1" });
}
if (ggml_cpu_has_wasm_simd()) {
features.push_back({ "WASM_SIMD", "1" });
}

View File

@ -14,6 +14,7 @@
#include "kai_lhs_pack_bf16p2vlx2_f32_sme.h" #include "kai_lhs_pack_bf16p2vlx2_f32_sme.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32.h" #include "kai_lhs_quant_pack_qsi8d32p_f32.h"
#include "kai_lhs_quant_pack_qsi8d32p4x8sb_f32_neon.h"
#include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h" #include "kai_lhs_quant_pack_qsi8d32p_f32_neon.h"
#include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h" #include "kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme.h"
@ -127,6 +128,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32_neon,
},
/* SME GEMV */ /* SME GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
@ -141,7 +148,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32_neon,
@ -173,6 +180,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
/* .pack_func = */ kai_run_lhs_pack_bf16p2vlx2_f32_sme,
},
/* SME GEMV */ /* SME GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
@ -187,7 +200,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
/* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa, /* .run_kernel = */ kai_run_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .get_offset = */ kai_get_lhs_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_pack_bf16p2vlx2_f32_sme,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme, /* .packed_size = */ kai_get_lhs_packed_size_lhs_pack_bf16p2vlx2_f32_sme,
@ -222,6 +235,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* DOTPROD GEMV */ /* DOTPROD GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@ -236,7 +255,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@ -270,6 +289,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
},
/* i8mm GEMV */ /* i8mm GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@ -284,7 +309,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@ -319,6 +344,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p4x8sb_f32_neon,
},
/* i8mm GEMV */ /* i8mm GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
@ -333,7 +364,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
@ -367,6 +398,12 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod,
}, },
/* .gemm_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,
/* .pack_func = */ kai_run_lhs_quant_pack_qsi8d32p_f32,
},
/* DOTPROD GEMV */ /* DOTPROD GEMV */
/* .kern_info = */ { /* .kern_info = */ {
/* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_m_step = */ kai_get_m_step_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
@ -381,7 +418,7 @@ static ggml_kleidiai_kernels gemm_gemv_kernels[] = {
/* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .get_dst_size = */ kai_get_dst_size_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
/* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod, /* .run_kernel = */ kai_run_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod,
}, },
/* .lhs_info = */ { /* .gemv_lhs_info = */ {
/* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_offset = */ kai_get_lhs_offset_lhs_quant_pack_qsi8d32p_f32,
/* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32, /* .get_packed_offset = */ kai_get_lhs_packed_offset_lhs_quant_pack_qsi8d32p_f32,
/* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32, /* .packed_size = */ kai_get_lhs_packed_size_lhs_quant_pack_qsi8d32p_f32,

View File

@ -84,8 +84,11 @@ struct rhs_packing_info {
struct ggml_kleidiai_kernels { struct ggml_kleidiai_kernels {
kernel_info gemm; kernel_info gemm;
lhs_packing_info gemm_lhs_info;
kernel_info gemv; kernel_info gemv;
lhs_packing_info lhs_info; lhs_packing_info gemv_lhs_info;
rhs_packing_info rhs_info; rhs_packing_info rhs_info;
cpu_feature required_cpu; cpu_feature required_cpu;

View File

@ -123,7 +123,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
} }
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op); ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, op);
GGML_ASSERT(kernels); GGML_ASSERT(kernels);
kernel_info * kernel = op->src[1]->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; bool is_gemv = op->src[1]->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
size_t k = op->src[0]->ne[0]; size_t k = op->src[0]->ne[0];
size_t n = op->src[0]->ne[1]; size_t n = op->src[0]->ne[1];
@ -134,9 +136,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
size_t sr = kernel->get_sr(); size_t sr = kernel->get_sr();
if (kernels->rhs_type == GGML_TYPE_Q4_0) { if (kernels->rhs_type == GGML_TYPE_Q4_0) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, QK4_0, mr, kr, sr); size = variant_call<size_t>(lhs_info->packed_size, m, k, QK4_0, mr, kr, sr);
} else if (kernels->rhs_type == GGML_TYPE_F16) { } else if (kernels->rhs_type == GGML_TYPE_F16) {
size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr) + size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr) +
variant_call<size_t>(kernels->rhs_info.packed_size, n, k) + variant_call<size_t>(kernels->rhs_info.packed_size, n, k) +
k * n * sizeof(float) + n * sizeof(float); k * n * sizeof(float) + n * sizeof(float);
} else { } else {
@ -173,7 +175,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels); GGML_ASSERT(kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; bool is_gemv = src1->ne[1] == 1;
kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel); GGML_ASSERT(kernel);
const int nth = params->nth; const int nth = params->nth;
@ -198,7 +202,7 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t kr = static_cast<int64_t>(kernel->get_kr()); const int64_t kr = static_cast<int64_t>(kernel->get_kr());
const int64_t sr = static_cast<int64_t>(kernel->get_sr()); const int64_t sr = static_cast<int64_t>(kernel->get_sr());
const size_t lhs_packed_size = variant_call<size_t>(kernels->lhs_info.packed_size, m, k, mr, kr, sr); const size_t lhs_packed_size = variant_call<size_t>(lhs_info->packed_size, m, k, mr, kr, sr);
const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k); const size_t rhs_packed_size = variant_call<size_t>(kernels->rhs_info.packed_size, n, k);
const size_t kxn_size = k * n * sizeof(float); const size_t kxn_size = k * n * sizeof(float);
const size_t bias_size = n * sizeof(float); const size_t bias_size = n * sizeof(float);
@ -229,12 +233,12 @@ class tensor_traits : public ggml::cpu::tensor_traits {
const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0; const int64_t num_m_per_thread = (ith == num_threads - 1) ? num_m_per_threadN_1 : num_m_per_thread0;
const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride); const size_t lhs_offset = variant_call<size_t>(kernels->gemm.get_lhs_offset, m_start, lhs_stride);
const size_t lhs_packed_offset = variant_call<size_t>(kernels->lhs_info.get_packed_offset, m_start, k, mr, kr, sr); const size_t lhs_packed_offset = variant_call<size_t>(lhs_info->get_packed_offset, m_start, k, mr, kr, sr);
const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset; const void * src_ptr = static_cast<const uint8_t *>(lhs_batch) + lhs_offset;
void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset; void * dst_ptr = static_cast<uint8_t *>(lhs_packed) + lhs_packed_offset;
variant_call<void>(kernels->lhs_info.pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr); variant_call<void>(lhs_info->pack_func, num_m_per_thread, k, mr, kr, sr, 0, src_ptr, lhs_stride, dst_ptr);
} }
} }
@ -306,8 +310,9 @@ class tensor_traits : public ggml::cpu::tensor_traits {
ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst); ggml_kleidiai_kernels *kernels = ggml_kleidiai_select_kernels(ctx.features, dst);
GGML_ASSERT(kernels); GGML_ASSERT(kernels);
kernel_info * kernel = src1->ne[1] == 1 ? &kernels->gemv : &kernels->gemm; bool is_gemv = src1->ne[1] == 1;
lhs_packing_info * lhs_info = &kernels->lhs_info; kernel_info * kernel = is_gemv ? &kernels->gemv : &kernels->gemm;
lhs_packing_info * lhs_info = is_gemv ? &kernels->gemv_lhs_info : &kernels->gemm_lhs_info;
GGML_ASSERT(kernel); GGML_ASSERT(kernel);

View File

@ -7027,6 +7027,209 @@ void ggml_compute_forward_im2col_back_f32(
}
}
// ggml_compute_forward_im2col_3d_f16
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f16(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src0->type == GGML_TYPE_F16);
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
GGML_UNUSED(OC);
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t KH_KW = KH*KW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
ggml_fp16_t * const wdata = (ggml_fp16_t *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
ggml_fp16_t * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = GGML_CPU_FP32_TO_FP16(*s);
}
}
}
}
}
}
}
}
}
}
}
// ggml_compute_forward_im2col_3d_f32
// src0: kernel [OC*IC, KD, KH, KW]
// src1: image [N*IC, ID, IH, IW]
// dst: result [N*OD, OH, OW, IC * KD * KH * KW]
static void ggml_compute_forward_im2col_3d_f32(
const ggml_compute_params * params,
ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS;
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int ith = params->ith;
const int nth = params->nth;
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
GGML_UNUSED(OC);
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t KH_KW = KH*KW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
GGML_ASSERT(nb10 == sizeof(float));
// im2col: [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
{
float * const wdata = (float *) dst->data;
for (int64_t in = 0; in < N; in++) {
for (int64_t iod = 0; iod < OD; iod++) {
for (int64_t ioh = 0; ioh < OH; ioh++) {
for (int64_t iow = 0; iow < OW; iow++) {
for (int64_t iic = ith; iic < IC; iic += nth) {
// micro kernel
float * dst_data = wdata + (in*OD*OH_OW + iod*OH_OW + ioh*OW + iow)*IC_KD_KH_KW; // [IC, KD, KH, KW]
const float * const src_data = (const float *) ((const char *)src1->data + (in*IC + iic)*nb13); // [ID, IH, IW]
for (int64_t ikd = 0; ikd < KD; ikd++) {
for (int64_t ikh = 0; ikh < KH; ikh++) {
for (int64_t ikw = 0; ikw < KW; ikw++) {
const int64_t iiw = iow*s0 + ikw*d0 - p0;
const int64_t iih = ioh*s1 + ikh*d1 - p1;
const int64_t iid = iod*s2 + ikd*d2 - p2;
if (iid < 0 || iid >= ID || iih < 0 || iih >= IH || iiw < 0 || iiw >= IW) {
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = 0;
} else {
const float * const s = (const float *) ((const char *)src_data + iid*nb12 + iih*nb11 + iiw*nb10); // [ID, IH, IW]
dst_data[iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw] = *s;
}
}
}
}
}
}
}
}
}
}
}
void ggml_compute_forward_im2col_3d(
const ggml_compute_params * params,
ggml_tensor * dst) {
switch (dst->type) {
case GGML_TYPE_F16:
{
ggml_compute_forward_im2col_3d_f16(params, dst);
} break;
case GGML_TYPE_F32:
{
ggml_compute_forward_im2col_3d_f32(params, dst);
} break;
default:
{
GGML_ABORT("fatal error");
}
}
}
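The ten op params follow the usual dilated-convolution convention; a small illustrative helper (not in the patch) for the output extents the kernels above assume:

    // one axis: input size I, kernel K, stride s, padding p, dilation d
    // (the same formula yields OD, OH and OW)
    static inline int64_t im2col_3d_out_dim(int64_t I, int64_t K, int32_t s, int32_t p, int32_t d) {
        return (I + 2*p - d*(K - 1) - 1) / s + 1;
    }
    // e.g. OD = im2col_3d_out_dim(ID, KD, s2, p2, d2)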
static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k, static void ggml_call_mul_mat(ggml_type type, const ggml_compute_params * params, int64_t m, int64_t n, int64_t k,
void * a, void * b, float * c) { void * a, void * b, float * c) {
const ggml_type_traits * traits = ggml_get_type_traits(type); const ggml_type_traits * traits = ggml_get_type_traits(type);
@ -8014,6 +8217,15 @@ static void ggml_compute_forward_pad_f32(
GGML_TENSOR_UNARY_OP_LOCALS
float * dst_ptr = (float *) dst->data;
const int32_t lp0 = ggml_get_op_params_i32(dst, 0);
const int32_t rp0 = ggml_get_op_params_i32(dst, 1);
const int32_t lp1 = ggml_get_op_params_i32(dst, 2);
const int32_t rp1 = ggml_get_op_params_i32(dst, 3);
const int32_t lp2 = ggml_get_op_params_i32(dst, 4);
const int32_t rp2 = ggml_get_op_params_i32(dst, 5);
const int32_t lp3 = ggml_get_op_params_i32(dst, 6);
const int32_t rp3 = ggml_get_op_params_i32(dst, 7);
// TODO: optimize
@ -8022,10 +8234,12 @@ static void ggml_compute_forward_pad_f32(
for (int64_t i0 = 0; i0 < ne0; ++i0) {
for (int64_t i3 = 0; i3 < ne3; ++i3) {
const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
if ((i0 >= lp0 && i0 < ne0 - rp0) \
const float * src_ptr = (const float *)((char *) src0->data + i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00); && (i1 >= lp1 && i1 < ne1 - rp1) \
&& (i2 >= lp2 && i2 < ne2 - rp2) \
if (i0 < ne00 && i1 < ne01 && i2 < ne02 && i3 < ne03) { && (i3 >= lp3 && i3 < ne3 - rp3)) {
const int64_t src_idx = (i3 - lp3)*nb03 + (i2 - lp2)*nb02 + (i1 - lp1)*nb01 + (i0 - lp0)*nb00;
const float * src_ptr = (const float *)((char *) src0->data + src_idx);
dst_ptr[dst_idx] = *src_ptr;
} else {
dst_ptr[dst_idx] = 0;
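Collapsed to one dimension, the new mapping amounts to the following sketch (illustrative only; lp0/rp0 as read from the op params above):

    // ne0 = ne00 + lp0 + rp0; copy the source into the window, zero-fill the pads
    for (int64_t i0 = 0; i0 < ne0; ++i0) {
        dst[i0] = (i0 >= lp0 && i0 < ne0 - rp0) ? src[i0 - lp0] : 0.0f;
    }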

View File

@ -69,6 +69,7 @@ void ggml_compute_forward_clamp(const struct ggml_compute_params * params, struc
void ggml_compute_forward_conv_transpose_1d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_back_f32(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_im2col_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_3d(const struct ggml_compute_params * params, struct ggml_tensor * dst);
void ggml_compute_forward_conv_transpose_2d(const struct ggml_compute_params * params, struct ggml_tensor * dst);

View File

@ -114,26 +114,6 @@ extern "C" {
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) riscv_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
#elif defined(__NNPA__)
#define GGML_CPU_COMPUTE_FP16_TO_FP32(x) nnpa_compute_fp16_to_fp32(x)
#define GGML_CPU_COMPUTE_FP32_TO_FP16(x) nnpa_compute_fp32_to_fp16(x)
#define GGML_CPU_FP16_TO_FP32(x) GGML_CPU_COMPUTE_FP16_TO_FP32(x)
#define GGML_CPU_FP32_TO_FP16(x) GGML_CPU_COMPUTE_FP32_TO_FP16(x)
static inline float nnpa_compute_fp16_to_fp32(ggml_fp16_t h) {
uint16x8_t v_h = vec_splats(h);
uint16x8_t v_hd = vec_convert_from_fp16(v_h, 0);
return vec_extend_to_fp32_hi(v_hd, 0)[0];
}
static inline ggml_fp16_t nnpa_compute_fp32_to_fp16(float f) {
float32x4_t v_f = vec_splats(f);
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_hd = vec_round_from_fp32(v_f, v_zero, 0);
uint16x8_t v_h = vec_convert_to_fp16(v_hd, 0);
return vec_extract(v_h, 0);
}
#endif
// precomputed f32 table for f16 (256 KB)
@ -215,6 +195,47 @@ inline static float ggml_lookup_fp16_to_fp32(ggml_fp16_t f) {
#define GGML_F32_VEC_MUL GGML_F32xt_MUL
#define GGML_F32_VEC_REDUCE GGML_F32xt_REDUCE
// F16 SVE
#define DEFAULT_PG32 svptrue_b32()
#define DEFAULT_PG16 svptrue_b16()
#define GGML_F32Cxt svfloat16_t
#define GGML_F32Cxt_ZERO svdup_n_f16(0.0f)
#define GGML_F32Cxt_SET1(x) svdup_n_f16(x)
#define GGML_F32Cxt_LOAD(p) svld1_f16(DEFAULT_PG16, (const __fp16 *)(p))
#define GGML_F32Cxt_STORE(dst_ptr, src_vec) svst1_f16(DEFAULT_PG16, (__fp16 *)(dst_ptr), (src_vec))
#define GGML_F32Cxt_FMA_IMPL(pg, a, b, c) svmad_f16_x(pg, b, c, a)
#define GGML_F32Cxt_FMA(...) GGML_F32Cxt_FMA_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_ADD_IMPL(pg, a, b) svadd_f16_x(pg, a, b)
#define GGML_F32Cxt_ADD(...) GGML_F32Cxt_ADD_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_MUL_IMPL(pg, a, b) svmul_f16_x(pg, a, b)
#define GGML_F32Cxt_MUL(...) GGML_F32Cxt_MUL_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F32Cxt_REDUCE GGML_F16xt_REDUCE_MIXED
#define GGML_F16x_VEC GGML_F32Cxt
#define GGML_F16x_VEC_ZERO GGML_F32Cxt_ZERO
#define GGML_F16x_VEC_SET1 GGML_F32Cxt_SET1
#define GGML_F16x_VEC_LOAD(p, i) GGML_F32Cxt_LOAD(p)
#define GGML_F16x_VEC_STORE(p, r, i) GGML_F32Cxt_STORE((__fp16 *)(p), r)
#define GGML_F16x_VEC_FMA GGML_F32Cxt_FMA
#define GGML_F16x_VEC_ADD GGML_F32Cxt_ADD
#define GGML_F16x_VEC_MUL GGML_F32Cxt_MUL
#define GGML_F16x_VEC_REDUCE GGML_F32Cxt_REDUCE
#define GGML_F16xt_REDUCE_ONE_IMPL(pg, a) svaddv_f16(pg, a)
#define GGML_F16xt_REDUCE_ONE(...) GGML_F16xt_REDUCE_ONE_IMPL(DEFAULT_PG16, __VA_ARGS__)
#define GGML_F16xt_REDUCE_MIXED_IMPL(pg16, res, sum1, sum2, sum3, sum4) \
{ \
sum1 = svadd_f16_x(pg16, sum1, sum2); \
sum3 = svadd_f16_x(pg16, sum3, sum4); \
sum1 = svadd_f16_x(pg16, sum1, sum3); \
__fp16 sum_f16 = svaddv_f16(pg16, sum1); \
(res) = (ggml_float) sum_f16; \
}
#define GGML_F16xt_REDUCE_MIXED(...) GGML_F16xt_REDUCE_MIXED_IMPL(DEFAULT_PG16, __VA_ARGS__)
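The new GGML_F16x_VEC_* wrappers compose like their F32 counterparts; a minimal accumulation sketch over ggml_fp16_t arrays x and y of length n (illustrative only; assumes n is a multiple of the f16 vector length):

    svfloat16_t acc = GGML_F16x_VEC_ZERO;
    for (int i = 0; i < n; i += (int) svcnth()) {        // svcnth() = f16 lanes per vector
        svfloat16_t vx = GGML_F16x_VEC_LOAD(x + i, 0);
        svfloat16_t vy = GGML_F16x_VEC_LOAD(y + i, 0);
        acc = GGML_F16x_VEC_FMA(acc, vx, vy);            // acc += vx * vy
    }
    // a single accumulator reduces with GGML_F16xt_REDUCE_ONE(acc);
    // GGML_F16x_VEC_REDUCE folds four accumulators, as in the dot-product kernels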
// F16 NEON
#if defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
@ -1115,11 +1136,6 @@ static inline void __lsx_f16x4_store(ggml_fp16_t * x, __m128 y) {
#define GGML_F16_EPR GGML_F32_EPR
static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
#if defined(__NNPA__)
uint16x8_t v_x = vec_xl(0, (const ggml_fp16_t *)x);
uint16x8_t v_xd = vec_convert_from_fp16(v_x, 0);
return vec_extend_to_fp32_hi(v_xd, 0);
#else
float tmp[4];
for (int i = 0; i < 4; i++) {
@ -1129,20 +1145,9 @@ static inline float32x4_t __lzs_f16cx4_load(const ggml_fp16_t * x) {
// note: keep type-cast here to prevent compiler bugs
// see: https://github.com/ggml-org/llama.cpp/issues/12846
return vec_xl(0, (const float *)(tmp));
#endif
}
static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
#if defined(__NNPA__)
float32x4_t v_zero = vec_splats(0.0f);
uint16x8_t v_xd = vec_round_from_fp32(v_y, v_zero, 0);
uint16x8_t v_x = vec_convert_to_fp16(v_xd, 0);
x[0] = vec_extract(v_x, 0);
x[1] = vec_extract(v_x, 1);
x[2] = vec_extract(v_x, 2);
x[3] = vec_extract(v_x, 3);
#else
float arr[4];
// note: keep type-cast here to prevent compiler bugs
@ -1152,7 +1157,6 @@ static inline void __lzs_f16cx4_store(ggml_fp16_t * x, float32x4_t v_y) {
for (int i = 0; i < 4; i++) {
x[i] = GGML_CPU_FP32_TO_FP16(arr[i]);
}
#endif
}
#define GGML_F16_VEC GGML_F32x4

View File

@ -85,15 +85,21 @@ void ggml_vec_dot_f32(int n, float * GGML_RESTRICT s, size_t bs, const float * G
// reduce sum1,sum2 to sum1
GGML_F32_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4, sum5, sum6, sum7, sum8);
#elif defined(__riscv_v_intrinsic) #elif defined(__riscv_v_intrinsic)
vfloat32m1_t vsum = __riscv_vfmv_v_f_f32m1(0.0f, 1); int vl = __riscv_vsetvlmax_e32m8();
for (int i = 0, avl; i < n; i += avl) { vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
avl = __riscv_vsetvl_e32m8(n - i); vfloat32m8_t vsum;
vfloat32m8_t ax = __riscv_vle32_v_f32m8(&x[i], avl); vfloat32m8_t ax;
vfloat32m8_t ay = __riscv_vle32_v_f32m8(&y[i], avl); vfloat32m8_t ay;
vfloat32m8_t prod = __riscv_vfmul_vv_f32m8(ax, ay, avl); vsum = __riscv_vfmv_v_f_f32m8_tu(vsum, 0.0f, vl);
vsum = __riscv_vfredusum_vs_f32m8_f32m1(prod, vsum, avl); for (int i = 0; i < n; i += vl) {
vl = __riscv_vsetvl_e32m8(n - i);
ax = __riscv_vle32_v_f32m8_tu(ax, &x[i], vl);
ay = __riscv_vle32_v_f32m8_tu(ay, &y[i], vl);
vsum = __riscv_vfmacc_vv_f32m8_tu(vsum, ax, ay, vl);
} }
sumf += __riscv_vfmv_f_s_f32m1_f32(vsum); vl = __riscv_vsetvlmax_e32m8();
vs = __riscv_vfredusum_vs_f32m8_f32m1(vsum, vs, vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
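The reworked loop keeps the products in vector lanes (vfmacc with tail-undisturbed loads) and performs a single vfredusum after the loop, instead of reducing on every iteration; the scalar shape is simply:

    float acc = 0.0f;              // vsum, held lane-wise in the vector code
    for (int i = 0; i < n; ++i) {
        acc += x[i] * y[i];        // vfmacc, one lane per element
    }
    sumf += acc;                   // the final vfredusum + vfmv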
#else
const int np = (n & ~(GGML_F32_STEP - 1));
@ -207,38 +213,125 @@ void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * G
ggml_float sumf = 0.0;
#if defined(GGML_SIMD) && !defined(__riscv_v_intrinsic)
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO }; #if defined(GGML_SIMD)
#if defined(__ARM_FEATURE_SVE)
const int sve_register_length = svcntb() * 8; //get vector length
const int ggml_f16_epr = sve_register_length / 16; // running when 16
const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
GGML_F16_VEC ax[GGML_F16_ARR]; const int np= (n & ~(ggml_f16_step - 1));
GGML_F16_VEC ay[GGML_F16_ARR]; svfloat16_t sum1 = svdup_n_f16(0.0f);
svfloat16_t sum2 = svdup_n_f16(0.0f);
svfloat16_t sum3 = svdup_n_f16(0.0f);
svfloat16_t sum4 = svdup_n_f16(0.0f);
for (int i = 0; i < np; i += GGML_F16_STEP) { svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
for (int j = 0; j < GGML_F16_ARR; j++) { svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); for (int i = 0; i < np; i += ggml_f16_step) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, ax1, ay1);
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]); ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
sum2 = GGML_F16x_VEC_FMA(sum2, ax2, ay2);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
sum3 = GGML_F16x_VEC_FMA(sum3, ax3, ay3);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
sum4 = GGML_F16x_VEC_FMA(sum4, ax4, ay4);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
sum1 = GGML_F16x_VEC_FMA(sum1, ax5, ay5);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
sum2 = GGML_F16x_VEC_FMA(sum2, ax6, ay6);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
sum3 = GGML_F16x_VEC_FMA(sum3, ax7, ay7);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
sum4 = GGML_F16x_VEC_FMA(sum4, ax8, ay8);
} }
}
// reduce sum0..sum3 to sum0 const int np2 = (n & ~(ggml_f16_epr - 1)); // round down to multiple of 8
GGML_F16_VEC_REDUCE(sumf, sum); for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
sum1 = GGML_F16x_VEC_FMA(sum1, rx, ry);
}
// leftovers if (np2 < n) {
for (int i = np; i < n; ++i) { svbool_t pg = svwhilelt_b16(np2, n);
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i])); svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
} svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
// if you hit this, you are likely running outside the FP range sum1 = svmad_f16_x(pg, hx, hy, sum1);
assert(!isnan(sumf) && !isinf(sumf)); }
GGML_F16x_VEC_REDUCE(sumf, sum1, sum2, sum3, sum4);
#elif defined(__riscv_v_intrinsic)
#if defined(__riscv_zvfh)
int vl = __riscv_vsetvlmax_e32m2();
vfloat32m1_t vs = __riscv_vfmv_v_f_f32m1(0.0f, 1);
vfloat32m2_t vsum;
vfloat16m1_t ax;
vfloat16m1_t ay;
vsum = __riscv_vreinterpret_v_u32m2_f32m2(__riscv_vmv_v_x_u32m2(0, vl));
for (int i = 0; i < n; i += vl) {
vl = __riscv_vsetvl_e16m1(n - i);
ax = __riscv_vle16_v_f16m1_tu(ax, (const _Float16 *)&x[i], vl);
ay = __riscv_vle16_v_f16m1_tu(ay, (const _Float16 *)&y[i], vl);
vsum = __riscv_vfwmacc_vv_f32m2_tu(vsum, ax, ay, vl);
}
vl = __riscv_vsetvlmax_e32m1();
vfloat32m1_t ac0 = __riscv_vfadd_vv_f32m1(__riscv_vget_v_f32m2_f32m1(vsum, 0), __riscv_vget_v_f32m2_f32m1(vsum, 1), vl);
vs = __riscv_vfredusum_vs_f32m1_f32m1(ac0, vs, vl);
sumf += __riscv_vfmv_f_s_f32m1_f32(vs);
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
#endif // __riscv_zvfh
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_F16_ARR] = { GGML_F16_VEC_ZERO };
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
sum[j] = GGML_F16_VEC_FMA(sum[j], ax[j], ay[j]);
}
}
// reduce sum0..sum3 to sum0
GGML_F16_VEC_REDUCE(sumf, sum);
// leftovers
for (int i = np; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
// if you hit this, you are likely running outside the FP range
assert(!isnan(sumf) && !isinf(sumf));
#endif
#else
for (int i = 0; i < n; ++i) {
sumf += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
#endif #endif // GGML_SIMD
*s = sumf;
}
@ -257,6 +350,12 @@ void ggml_vec_silu_f32(const int n, float * y, const float * x) {
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, ggml_v_silu(_mm_loadu_ps(x + i)));
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svst1_f32(pg, y + i, ggml_v_silu(pg, svld1_f32(pg, x + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, ggml_v_silu(vld1q_f32(x + i)));
@ -281,10 +380,24 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float *
for (; i + 3 < n; i += 4) {
_mm_storeu_ps(y + i, _mm_mul_ps(ggml_v_silu(_mm_loadu_ps(x + i)), _mm_loadu_ps(g + i)));
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svst1_f32(pg, y + i, svmul_f32_x(pg, ggml_v_silu(pg, svld1_f32(pg, x + i)), svld1_f32(pg, g + i)));
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
vst1q_f32(y + i, vmulq_f32(ggml_v_silu(vld1q_f32(x + i)), vld1q_f32(g + i)));
}
#elif defined(__riscv_v_intrinsic)
for (int vl; i < n; i += vl) {
vl = __riscv_vsetvl_e32m2(n - i);
vfloat32m2_t vx = __riscv_vle32_v_f32m2(&x[i], vl);
vfloat32m2_t vg = __riscv_vle32_v_f32m2(&g[i], vl);
vfloat32m2_t vy = __riscv_vfmul_vv_f32m2(ggml_v_silu_m2(vx, vl), vg, vl);
__riscv_vse32_v_f32m2(&y[i], vy, vl);
}
#endif
for (; i < n; ++i) {
y[i] = ggml_silu_f32(x[i]) * g[i];
@ -328,6 +441,15 @@ ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float
#endif
sum += (ggml_float)_mm_cvtss_f32(val);
}
#elif defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
const int vlen = svcntw();
for (; i < n; i += vlen) {
const svbool_t pg = svwhilelt_b32_s32(i, n);
svfloat32_t val = ggml_v_expf(pg, svsub_f32_x(pg, svld1_f32(pg, x + i),
svdup_n_f32_x(pg, max)));
svst1_f32(pg, y + i, val);
sum += (ggml_float)svaddv_f32(pg, val);
}
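The three SVE loops added to this file all use the same whilelt idiom, which folds the tail into the last iteration instead of needing a scalar remainder loop; the bare pattern (illustrative only):

    for (int i = 0; i < n; i += (int) svcntw()) {          // svcntw() = f32 lanes per vector
        const svbool_t pg = svwhilelt_b32_s32(i, n);       // lanes at or past n are inactive
        svst1_f32(pg, y + i, svld1_f32(pg, x + i));        // masked load/store handle the tail
    }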
#elif defined(__ARM_NEON) && defined(__aarch64__)
for (; i + 3 < n; i += 4) {
float32x4_t val = ggml_v_expf(vsubq_f32(vld1q_f32(x + i),

View File

@ -119,45 +119,149 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG
}
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic) #if defined(__ARM_FEATURE_SVE)
// todo: RVV impl
for (int i = 0; i < n; ++i) { const int sve_register_length = svcntb() * 8;
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) { const int ggml_f16_epr = sve_register_length / 16; // running when 16
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i])); const int ggml_f16_step = 8 * ggml_f16_epr; // choose 8 SVE registers
const int np = (n & ~(ggml_f16_step - 1));
svfloat16_t sum_00 = svdup_n_f16(0.0f);
svfloat16_t sum_01 = svdup_n_f16(0.0f);
svfloat16_t sum_02 = svdup_n_f16(0.0f);
svfloat16_t sum_03 = svdup_n_f16(0.0f);
svfloat16_t sum_10 = svdup_n_f16(0.0f);
svfloat16_t sum_11 = svdup_n_f16(0.0f);
svfloat16_t sum_12 = svdup_n_f16(0.0f);
svfloat16_t sum_13 = svdup_n_f16(0.0f);
svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
for (int i = 0; i < np; i += ggml_f16_step) {
ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements
ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1
ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements
ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2);
ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3);
ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ax4 = GGML_F16x_VEC_LOAD(x[0] + i + 3*ggml_f16_epr, 3);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax4, ay4);
ax4 = GGML_F16x_VEC_LOAD(x[1] + i + 3*ggml_f16_epr, 3);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax4, ay4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ax5 = GGML_F16x_VEC_LOAD(x[0] + i + 4*ggml_f16_epr, 4);
sum_00 = GGML_F16x_VEC_FMA(sum_00, ax5, ay5);
ax5 = GGML_F16x_VEC_LOAD(x[1] + i + 4*ggml_f16_epr, 4);
sum_10 = GGML_F16x_VEC_FMA(sum_10, ax5, ay5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ax6 = GGML_F16x_VEC_LOAD(x[0] + i + 5*ggml_f16_epr, 5);
sum_01 = GGML_F16x_VEC_FMA(sum_01, ax6, ay6);
ax6 = GGML_F16x_VEC_LOAD(x[1] + i + 5*ggml_f16_epr, 5);
sum_11 = GGML_F16x_VEC_FMA(sum_11, ax6, ay6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ax7 = GGML_F16x_VEC_LOAD(x[0] + i + 6*ggml_f16_epr, 6);
sum_02 = GGML_F16x_VEC_FMA(sum_02, ax7, ay7);
ax7 = GGML_F16x_VEC_LOAD(x[1] + i + 6*ggml_f16_epr, 6);
sum_12 = GGML_F16x_VEC_FMA(sum_12, ax7, ay7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ax8 = GGML_F16x_VEC_LOAD(x[0] + i + 7*ggml_f16_epr, 7);
sum_03 = GGML_F16x_VEC_FMA(sum_03, ax8, ay8);
ax8 = GGML_F16x_VEC_LOAD(x[1] + i + 7*ggml_f16_epr, 7);
sum_13 = GGML_F16x_VEC_FMA(sum_13, ax8, ay8);
} }
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } }; const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
GGML_F16_VEC ax[GGML_F16_ARR]; svfloat16_t rx = GGML_F16x_VEC_LOAD(x[0] + k, 0);
GGML_F16_VEC ay[GGML_F16_ARR]; sum_00 = GGML_F16x_VEC_FMA(sum_00, rx, ry);
rx = GGML_F16x_VEC_LOAD(x[1] + k, 0);
sum_10 = GGML_F16x_VEC_FMA(sum_10, rx, ry);
}
for (int i = 0; i < np; i += GGML_F16_STEP) { if (np2 < n) {
for (int j = 0; j < GGML_F16_ARR; j++) { svbool_t pg = svwhilelt_b16(np2, n);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); svfloat16_t hx_0 = svld1_f16(pg, (const __fp16 *)(x[0] + np2));
svfloat16_t hx_1 = svld1_f16(pg, (const __fp16 *)(x[1] + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { sum_00 = svmad_f16_x(pg, hx_0, hy, sum_00);
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j); sum_10 = svmad_f16_x(pg, hx_1, hy, sum_10);
}
GGML_F16x_VEC_REDUCE(sumf[0], sum_00, sum_01, sum_02, sum_03);
GGML_F16x_VEC_REDUCE(sumf[1], sum_10, sum_11, sum_12, sum_13);
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]); GGML_F16_VEC sum[GGML_VEC_DOT_UNROLL][GGML_F16_ARR] = { { GGML_F16_VEC_ZERO } };
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
ax[j] = GGML_F16_VEC_LOAD(x[k] + i + j*GGML_F16_EPR, j);
sum[k][j] = GGML_F16_VEC_FMA(sum[k][j], ax[j], ay[j]);
}
} }
} }
}
// reduce sum0..sum3 to sum0 // reduce sum0..sum3 to sum0
for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) { for (int k = 0; k < GGML_VEC_DOT_UNROLL; ++k) {
GGML_F16_VEC_REDUCE(sumf[k], sum[k]); GGML_F16_VEC_REDUCE(sumf[k], sum[k]);
}
// leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
} }
}
#endif // leftovers
for (int i = np; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
sumf[j] += (ggml_float)(GGML_CPU_FP16_TO_FP32(x[j][i])*GGML_CPU_FP16_TO_FP32(y[i]));
}
}
#endif
#else
for (int i = 0; i < n; ++i) {
for (int j = 0; j < GGML_VEC_DOT_UNROLL; ++j) {
@ -293,35 +397,112 @@ inline static void ggml_vec_mad_f32(const int n, float * GGML_RESTRICT y, const
inline static void ggml_vec_mad_f16(const int n, ggml_fp16_t * GGML_RESTRICT y, const ggml_fp16_t * GGML_RESTRICT x, const float v) {
#if defined(GGML_SIMD)
#if defined(__riscv_v_intrinsic) #if defined(__ARM_FEATURE_SVE)
// todo: RVV impl const int sve_register_length = svcntb() * 8;
// scalar const int ggml_f16_epr = sve_register_length / 16;
for (int i = 0; i < n; ++i) { const int ggml_f16_step = 8 * ggml_f16_epr;
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v); GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR]; const int np= (n & ~(ggml_f16_step - 1));
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) { svfloat16_t ax1, ax2, ax3, ax4, ax5, ax6, ax7, ax8;
for (int j = 0; j < GGML_F16_ARR; j++) { svfloat16_t ay1, ay2, ay3, ay4, ay5, ay6, ay7, ay8;
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j); for (int i = 0; i < np; i += ggml_f16_step) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j); ax1 = GGML_F16x_VEC_LOAD(x + i + 0 * ggml_f16_epr, 0);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx); ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0);
ay1 = GGML_F16x_VEC_FMA(ay1, ax1, vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j); GGML_F16x_VEC_STORE(y + i + 0 * ggml_f16_epr, ay1, 0);
ax2 = GGML_F16x_VEC_LOAD(x + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1);
ay2 = GGML_F16x_VEC_FMA(ay2, ax2, vx);
GGML_F16x_VEC_STORE(y + i + 1 * ggml_f16_epr, ay2, 1);
ax3 = GGML_F16x_VEC_LOAD(x + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_LOAD(y + i + 2 * ggml_f16_epr, 2);
ay3 = GGML_F16x_VEC_FMA(ay3, ax3, vx);
GGML_F16x_VEC_STORE(y + i + 2 * ggml_f16_epr, ay3, 2);
ax4 = GGML_F16x_VEC_LOAD(x + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3);
ay4 = GGML_F16x_VEC_FMA(ay4, ax4, vx);
GGML_F16x_VEC_STORE(y + i + 3 * ggml_f16_epr, ay4, 3);
ax5 = GGML_F16x_VEC_LOAD(x + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_LOAD(y + i + 4 * ggml_f16_epr, 4);
ay5 = GGML_F16x_VEC_FMA(ay5, ax5, vx);
GGML_F16x_VEC_STORE(y + i + 4 * ggml_f16_epr, ay5, 4);
ax6 = GGML_F16x_VEC_LOAD(x + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_LOAD(y + i + 5 * ggml_f16_epr, 5);
ay6 = GGML_F16x_VEC_FMA(ay6, ax6, vx);
GGML_F16x_VEC_STORE(y + i + 5 * ggml_f16_epr, ay6, 5);
ax7 = GGML_F16x_VEC_LOAD(x + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_LOAD(y + i + 6 * ggml_f16_epr, 6);
ay7 = GGML_F16x_VEC_FMA(ay7, ax7, vx);
GGML_F16x_VEC_STORE(y + i + 6 * ggml_f16_epr, ay7, 6);
ax8 = GGML_F16x_VEC_LOAD(x + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_LOAD(y + i + 7 * ggml_f16_epr, 7);
ay8 = GGML_F16x_VEC_FMA(ay8, ax8, vx);
GGML_F16x_VEC_STORE(y + i + 7 * ggml_f16_epr, ay8, 7);
        }

        const int np2 = (n & ~(ggml_f16_epr - 1));
for (int k = np; k < np2; k += ggml_f16_epr) {
svfloat16_t rx = GGML_F16x_VEC_LOAD(x + k, 0);
svfloat16_t ry = GGML_F16x_VEC_LOAD(y + k, 0);
ry = GGML_F16x_VEC_FMA(ry, rx, vx);
            GGML_F16x_VEC_STORE(y + k, ry, 0);
        }

        if (np2 < n) {
            svbool_t pg = svwhilelt_b16(np2, n);
svfloat16_t hx = svld1_f16(pg, (const __fp16 *)(x + np2));
svfloat16_t hy = svld1_f16(pg, (const __fp16 *)(y + np2));
hy = svmad_f16_x(pg, hx, vx, hy);
svst1_f16(pg, (__fp16 *)(y + np2), hy);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
GGML_F16_VEC ax[GGML_F16_ARR];
GGML_F16_VEC ay[GGML_F16_ARR];
for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ax[j] = GGML_F16_VEC_LOAD(x + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_FMA(ay[j], ax[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i]) + GGML_CPU_FP16_TO_FP32(x[i])*v);
}
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
@ -517,33 +698,59 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) {
inline static void ggml_vec_scale_f16(const int n, ggml_fp16_t * y, const float v) {
#if defined(GGML_SIMD)
    #if defined(__ARM_FEATURE_SVE)
        const int sve_register_length = svcntb() * 8;
        const int ggml_f16_epr = sve_register_length / 16;
        const int ggml_f16_step = 2 * ggml_f16_epr;

        GGML_F16x_VEC vx = GGML_F16x_VEC_SET1(v);

        const int np = (n & ~(ggml_f16_step - 1));
        svfloat16_t ay1, ay2;
        for (int i = 0; i < np; i += ggml_f16_step) {
            ay1 = GGML_F16x_VEC_LOAD(y + i + 0*ggml_f16_epr, 0);
            ay1 = GGML_F16x_VEC_MUL(ay1, vx);
            GGML_F16x_VEC_STORE(y + i + 0*ggml_f16_epr, ay1, 0);

            ay2 = GGML_F16x_VEC_LOAD(y + i + 1*ggml_f16_epr, 1);
            ay2 = GGML_F16x_VEC_MUL(ay2, vx);
            GGML_F16x_VEC_STORE(y + i + 1*ggml_f16_epr, ay2, 1);
        }
        // leftovers
        // the maximum number of leftover elements will be less than ggml_f16_epr; apply a predicated svmul on the available elements only
if (np < n) {
svbool_t pg = svwhilelt_b16(np, n);
svfloat16_t hy = svld1_f16(pg, (__fp16 *)(y + np));
svfloat16_t out = svmul_f16_m(pg, hy, vx);
svst1_f16(pg, (__fp16 *)(y + np), out);
}
#elif defined(__riscv_v_intrinsic)
// todo: RVV impl
// scalar
for (int i = 0; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#else
const int np = (n & ~(GGML_F16_STEP - 1));
// leftovers GGML_F16_VEC vx = GGML_F16_VEC_SET1(v);
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v); GGML_F16_VEC ay[GGML_F16_ARR];
}
#endif for (int i = 0; i < np; i += GGML_F16_STEP) {
for (int j = 0; j < GGML_F16_ARR; j++) {
ay[j] = GGML_F16_VEC_LOAD(y + i + j*GGML_F16_EPR, j);
ay[j] = GGML_F16_VEC_MUL(ay[j], vx);
GGML_F16_VEC_STORE(y + i + j*GGML_F16_EPR, ay, j);
}
}
// leftovers
for (int i = np; i < n; ++i) {
y[i] = GGML_CPU_FP32_TO_FP16(GGML_CPU_FP16_TO_FP32(y[i])*v);
}
#endif
#else
    // scalar
    for (int i = 0; i < n; ++i) {
@ -795,7 +1002,39 @@ https://github.com/openvinotoolkit/openvino/blob/master/src/plugins/intel_cpu/sr
}
#endif

#if defined(__ARM_FEATURE_SVE) && defined(__aarch64__)
inline static svfloat32_t ggml_v_expf(svbool_t pg, svfloat32_t x) {
const svfloat32_t r = svdup_n_f32_x(pg, 0x1.8p23f);
const svfloat32_t z = svmla_n_f32_x(pg, r, x, 0x1.715476p+0f);
const svfloat32_t n = svsub_f32_x(pg, z, r);
const svfloat32_t b = svmls_n_f32_x(pg, svmls_n_f32_x(pg, x, n, 0x1.62e4p-1f), n, 0x1.7f7d1cp-20f);
const svuint32_t e = svlsl_n_u32_x(pg, svreinterpret_u32_f32(z), 23);
const svfloat32_t k = svreinterpret_f32_u32(svadd_u32_x(pg, e, svreinterpret_u32_f32(svdup_n_f32_x(pg, 1))));
const svbool_t c = svacgt_n_f32(pg, n, 126);
const svfloat32_t u = svmul_f32_x(pg, b, b);
const svfloat32_t j = svmla_f32_x(pg,
svmul_n_f32_x(pg, b, 0x1.ffffecp-1f),
svmla_f32_x(pg, svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.fffdb6p-2f), svdup_n_f32_x(pg, 0x1.555e66p-3f), b),
svmla_f32_x(pg, svdup_n_f32_x(pg, 0x1.573e2ep-5f), svdup_n_f32_x(pg, 0x1.0e4020p-7f), b), u), u);
const svuint32_t d = svdup_n_u32_z(svcmple_n_f32(pg, n, 0.0), 0x82000000);
const svfloat32_t s1 = svreinterpret_f32_u32(svadd_n_u32_x(pg, d, 0x7f000000));
const svfloat32_t s2 = svreinterpret_f32_u32(svsub_u32_x(pg, e, d));
return svsel_f32(svacgt_f32(pg, n, svdup_n_f32_x(pg, 192)), svmul_f32_x(pg, s1, s1),
svsel_f32(c, svmul_f32_x(pg, svmla_f32_x(pg, s2, s2, j), s1), svmla_f32_x(pg, k, k, j)));
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static svfloat32_t ggml_v_silu(svbool_t pg, svfloat32_t x) {
const svfloat32_t one = svdup_n_f32_x(pg, 1.0f);
const svfloat32_t zero = svdup_n_f32_x(pg, 0.0f);
const svfloat32_t neg_x = svsub_f32_x(pg, zero, x);
const svfloat32_t exp_neg_x = ggml_v_expf(pg, neg_x);
const svfloat32_t one_plus_exp_neg_x = svadd_f32_x(pg, one, exp_neg_x);
return svdiv_f32_x(pg, x, one_plus_exp_neg_x);
}
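// For reference, a scalar sketch of what the two SVE helpers above compute per lane
// (illustrative only, assuming plain <math.h>; not part of this patch):
//
//   float expf_ref(float x) { return expf(x); }
//   float silu_ref(float x) { return x / (1.0f + expf(-x)); }
//
// ggml_v_expf is a polynomial approximation of expf evaluated under predicate pg,
// and ggml_v_silu composes it as x / (1 + exp(-x)).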
#elif defined(__ARM_NEON) && defined(__aarch64__)
// adapted from arm limited optimized routine
// the maximum error is 1.45358 plus 0.5 ulps
@ -1030,6 +1269,14 @@ inline static vfloat32m2_t ggml_v_expf_m2(vfloat32m2_t x, int vl) {
        vl);
}
// computes silu x/(1+exp(-x)) in single precision vector
inline static vfloat32m2_t ggml_v_silu_m2(vfloat32m2_t x, int vl) {
const vfloat32m2_t neg_x = __riscv_vfneg_v_f32m2(x, vl);
const vfloat32m2_t exp_neg_x = ggml_v_expf_m2(neg_x, vl);
const vfloat32m2_t one_plus_exp_neg_x = __riscv_vfadd_vf_f32m2(exp_neg_x, 1.0f, vl);
return __riscv_vfdiv_vv_f32m2(x, one_plus_exp_neg_x, vl);
}
#endif // __ARM_NEON / __AVX2__ / __SSE2__ / __riscv_v_intrinsic

inline static void ggml_vec_silu_f16(const int n, ggml_fp16_t * y, const ggml_fp16_t * x) {

View File

@ -563,6 +563,40 @@ static __device__ __forceinline__ float ggml_cuda_e8m0_to_fp32(uint8_t x) {
#endif // CUDART_VERSION >= 12050
}
// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1.
// Precompute mp (m' in the paper) and L such that division
// can be computed using a multiply (high 32b of 64b result)
// and a shift:
//
// n/d = (mulhi(n, mp) + n) >> L;
static const uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
// compute L = ceil(log2(d));
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
// pack divisor as well to reduce error surface
return make_uint3(mp, L, d);
}
static __device__ __forceinline__ uint32_t fastdiv(uint32_t n, const uint3 fastdiv_values) {
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z>
// fastdiv_values.z is unused and optimized away by the compiler.
// Compute high 32 bits of n * mp
const uint32_t hi = __umulhi(n, fastdiv_values.x);
// add n, apply bit shift
return (hi + n) >> fastdiv_values.y;
}
static __device__ __forceinline__ uint32_t fastmodulo(uint32_t n, const uint3 fastdiv_values) {
// expects fastdiv_values to contain <mp, L, divisor> in <x, y, z> (see init_fastdiv_values)
return n - fastdiv(n, fastdiv_values) * fastdiv_values.z;
}
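// For reference, a minimal host-side check of the scheme above (illustrative only;
// the helper name and the 64-bit emulation of __umulhi are assumptions, not part of
// this patch). For every n, fastdiv/fastmodulo must agree with ordinary / and %:
//
//   static void check_fastdiv(uint32_t d) {
//       const uint3 fd = init_fastdiv_values(d);                          // {mp, L, d}
//       for (uint32_t n = 0; n < (1u << 20); ++n) {
//           const uint32_t hi = (uint32_t) (((uint64_t) n * fd.x) >> 32); // __umulhi(n, mp)
//           const uint32_t q  = (hi + n) >> fd.y;
//           GGML_ASSERT(q == n / d && n - q*fd.z == n % d);
//       }
//   }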
typedef void (*dequantize_kernel_t)(const void * vx, const int64_t ib, const int iqs, float2 & v);

static __device__ __forceinline__ float get_alibi_slope(

View File

@ -1,4 +1,5 @@
#include "conv2d.cuh" #include "conv2d.cuh"
#include "convert.cuh"
struct conv_params { struct conv_params {
const int64_t IW, IH; const int64_t IW, IH;
@ -82,7 +83,7 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
    int64_t n, c_out, out_y, out_x;
    Layout::unpack_indices(global_idx, P, n, c_out, out_y, out_x);
T acc = 0; float acc = 0.0f;
    for (int64_t c_in = 0; c_in < P.IC; ++c_in) {
        kernel_bounds bounds = calculate_kernel_bounds(out_x, out_y, P);
@ -93,21 +94,15 @@ static __global__ void conv2d_kernel(const float * __restrict__ input,
            for (int64_t kx = bounds.x_min; kx < bounds.x_max; ++kx) {
                const int64_t in_x = calculate_input_coord(out_x, kx, P.ST_X, P.DL_X, P.PD_X);
T input_val; const float input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
if (std::is_same<T, half>::value) { const T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
input_val = __float2half(input[Layout::input_index(n, c_in, in_y, in_x, P)]); acc += (input_val * ggml_cuda_cast<float>(kernel_val));
} else {
input_val = input[Layout::input_index(n, c_in, in_y, in_x, P)];
}
T kernel_val = kernel[Layout::kernel_index(c_out, c_in, ky, kx, P)];
acc += (input_val * kernel_val);
            }
        }
    }

    // [N, OC, OH, OW]
output[Layout::output_index(n, c_out, out_y, out_x, P)] = (float) acc; output[Layout::output_index(n, c_out, out_y, out_x, P)] = acc;
}

template <typename T>

View File

@ -2,6 +2,8 @@
#include "dequantize.cuh" #include "dequantize.cuh"
#include "convert.cuh" #include "convert.cuh"
#define MAX_GRIDDIM_Y 65535
template<int qk, int qr, dequantize_kernel_t dequantize_kernel, typename dst_t>
static __global__ void k_get_rows(
        const void * __restrict__ src0, const int32_t * __restrict__ src1, dst_t * __restrict__ dst,
@ -11,32 +13,29 @@ static __global__ void k_get_rows(
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. for (int64_t i00 = 2*(blockIdx.y*blockDim.x + threadIdx.x); i00 < ne00; i00 += gridDim.y*blockDim.x) {
const int i00 = (blockIdx.y * blockDim.x + threadIdx.x)*2; // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x; const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12; const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12; const int i12 = blockIdx.z % ne12;
if (i00 >= ne00) { const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
return;
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
} }
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const void * src0_row = (const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03;
const int ib = i00/qk; // block index
const int iqs = (i00%qk)/qr; // quant index
const int iybs = i00 - i00%qk; // dst block start index
const int y_offset = qr == 1 ? 1 : qk/2;
// dequantize
float2 v;
dequantize_kernel(src0_row, ib, iqs, v);
dst_row[iybs + iqs + 0] = ggml_cuda_cast<dst_t>(v.x);
dst_row[iybs + iqs + y_offset] = ggml_cuda_cast<dst_t>(v.y);
}

template<typename src0_t, typename dst_t>
@ -48,22 +47,23 @@ static __global__ void k_get_rows_float(
        /*const size_t nb00,*/ const size_t nb01, const size_t nb02, const size_t nb03,
        const size_t s10, const size_t s11, const size_t s12/*, const size_t s13*/) {
// The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher. for (int64_t i00 = blockIdx.y*blockDim.x + threadIdx.x; i00 < ne00; i00 += gridDim.y*blockDim.x) {
const int i00 = blockIdx.y * blockDim.x + threadIdx.x; // The x and y dimensions of the grid are swapped because the maximum allowed grid size for x is higher.
const int i10 = blockIdx.x; const int i10 = blockIdx.x;
const int i11 = blockIdx.z / ne12; const int i11 = blockIdx.z / ne12;
const int i12 = blockIdx.z % ne12; const int i12 = blockIdx.z % ne12;
        if (i00 >= ne00) {
            return;
}
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
} }
const int i01 = src1[i10*s10 + i11*s11 + i12*s12];
dst_t * dst_row = dst + i10*s1 + i11*s2 + i12*s3;
const src0_t * src0_row = (const src0_t *)((const char *) src0 + i01*nb01 + i11*nb02 + i12*nb03);
dst_row[i00] = ggml_cuda_cast<dst_t>(src0_row[i00]);
} }
template<typename grad_t, typename dst_t>
@ -98,7 +98,7 @@ static void get_rows_cuda_q(
        cudaStream_t stream) {
    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_y = (ne00 + 2*CUDA_GET_ROWS_BLOCK_SIZE - 1) / (2*CUDA_GET_ROWS_BLOCK_SIZE);
const dim3 block_nums(ne10, block_num_y, ne11*ne12); const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
    // strides in elements
    // const size_t s0 = nb0 / sizeof(dst_t);
@ -131,7 +131,7 @@ static void get_rows_cuda_float(
        cudaStream_t stream) {
    const dim3 block_dims(CUDA_GET_ROWS_BLOCK_SIZE, 1, 1);
    const int block_num_y = (ne00 + CUDA_GET_ROWS_BLOCK_SIZE - 1) / CUDA_GET_ROWS_BLOCK_SIZE;
const dim3 block_nums(ne10, block_num_y, ne11*ne12); const dim3 block_nums(ne10, MIN(block_num_y, MAX_GRIDDIM_Y), ne11*ne12);
    // strides in elements
    // const size_t s0 = nb0 / sizeof(dst_t);
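// Note on the change above: k_get_rows/k_get_rows_float now iterate i00 with a
// grid-stride loop, so the launch can clamp gridDim.y to MAX_GRIDDIM_Y (65535,
// CUDA's limit for the y dimension) instead of requiring one block per chunk.
// Rough sketch of the effect, assuming CUDA_GET_ROWS_BLOCK_SIZE == 256:
//
//   ne00        = 1 << 26;                             // 67,108,864 values per row
//   block_num_y = (ne00 + 2*256 - 1) / (2*256);        // 131,072 blocks would be needed
//   gridDim.y   = MIN(block_num_y, MAX_GRIDDIM_Y);     // clamped; the loop covers the rest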

View File

@ -2452,6 +2452,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
        case GGML_OP_IM2COL:
            ggml_cuda_op_im2col(ctx, dst);
            break;
case GGML_OP_IM2COL_3D:
ggml_cuda_op_im2col_3d(ctx, dst);
break;
        case GGML_OP_CONV_2D:
            ggml_cuda_op_conv2d(ctx, dst);
            break;
@ -3559,6 +3562,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
                return op->src[0]->nb[0] == ggml_type_size(op->src[0]->type) && ggml_is_contiguous_2(op->src[0]);
            }
        case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
        case GGML_OP_CONV_2D:
        case GGML_OP_CONV_2D_DW:
        case GGML_OP_CONV_TRANSPOSE_2D:

View File

@ -112,3 +112,132 @@ void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
        im2col_cuda_f32(src1_d, (float *) dst_d, IW, IH, OW, OH, KW, KH, IC, N, IC_IH_IW, IH_IW, s0, s1, p0, p1, d0, d1, stream);
    }
}
} }
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static __global__ void im2col_3d_kernel(
const float * src, T * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int64_t OH_OW, int64_t KD_KH_KW, int64_t ID_IH_IW, int64_t KH_KW, int64_t IH_IW, int64_t IC_ID_IH_IW,
int64_t IC_KD_KH_KW, int64_t OW_KD_KH_KW, int64_t OD_OH_OW_IC_KD_KH_KW, int64_t OH_OW_IC_KD_KH_KW,
int64_t OW_IC_KD_KH_KW, int64_t N_OD_OH, int64_t OD_OH,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2) {
const int64_t i = threadIdx.x + blockIdx.x * blockDim.x;
if (i >= IC_KD_KH_KW) {
return;
}
const int64_t iic = i / KD_KH_KW;
const int64_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const int64_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const int64_t ikw = i % KW;
const int64_t iow = blockIdx.y;
for (int64_t iz = blockIdx.z; iz < N_OD_OH; iz+=MAX_GRIDDIM_Z) {
const int64_t in = iz / OD_OH;
const int64_t iod = (iz - in*OD_OH) / OH;
const int64_t ioh = iz % OH;
const int64_t iiw = iow * s0 + ikw * d0 - p0;
const int64_t iih = ioh * s1 + ikh * d1 - p1;
const int64_t iid = iod * s2 + ikd * d2 - p2;
const int64_t offset_dst = in*OD_OH_OW_IC_KD_KH_KW + iod*OH_OW_IC_KD_KH_KW + ioh*OW_IC_KD_KH_KW + iow*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
if (iih < 0 || iih >= IH || iiw < 0 || iiw >= IW || iid < 0 || iid >= ID) {
dst[offset_dst] = 0.0f;
} else {
const int64_t offset_src = in*IC_ID_IH_IW + iic*ID_IH_IW + iid*IH_IW + iih*IW + iiw;
dst[offset_dst] = src[offset_src];
}
}
}
// [N*IC, ID, IH, IW] => [N*OD, OH, OW, IC * KD * KH * KW]
template <typename T>
static void im2col_3d_cuda(const float * src, T* dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
const int64_t OH_OW = OH*OW;
const int64_t KD_KH_KW = KD*KH*KW;
const int64_t ID_IH_IW = ID*IH*IW;
const int64_t KH_KW = KH*KW;
const int64_t IH_IW = IH*IW;
const int64_t IC_KD_KH_KW = IC*KD*KH*KW;
const int64_t OW_KD_KH_KW = OW*KD*KH*KW;
const int64_t N_OD_OH = N*OD*OH;
const int64_t OD_OH = OD*OH;
const int64_t IC_ID_IH_IW = IC*ID*IH*IW;
const int64_t OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW;
const int64_t OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW;
const int64_t OW_IC_KD_KH_KW = OW*IC*KD*KH*KW;
const int64_t num_blocks = (IC_KD_KH_KW + CUDA_IM2COL_BLOCK_SIZE - 1) / CUDA_IM2COL_BLOCK_SIZE;
dim3 block_nums(num_blocks, OW, MIN(N_OD_OH, MAX_GRIDDIM_Z));
im2col_3d_kernel<<<block_nums, MIN(IC_KD_KH_KW, CUDA_IM2COL_BLOCK_SIZE) , 0, stream>>>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW,
OH_OW, KD_KH_KW, ID_IH_IW, KH_KW, IH_IW, IC_ID_IH_IW,
IC_KD_KH_KW, OW_KD_KH_KW, OD_OH_OW_IC_KD_KH_KW,
OH_OW_IC_KD_KH_KW, OW_IC_KD_KH_KW, N_OD_OH, OD_OH,
s0, s1, s2, p0, p1, p2, d0, d1, d2);
}
static void im2col_3d_cuda_f16(const float * src, half * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<half>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
static void im2col_3d_cuda_f32(const float * src, float * dst,
int64_t N, int64_t IC, int64_t ID, int64_t IH, int64_t IW, int64_t OC,
int64_t KD, int64_t KH, int64_t KW, int64_t OD, int64_t OH, int64_t OW,
int s0, int s1, int s2, int p0, int p1, int p2, int d0, int d1, int d2, cudaStream_t stream) {
im2col_3d_cuda<float>(src, dst, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const ggml_tensor * src1 = dst->src[1];
const float * src1_d = (const float *)src1->data;
float * dst_d = (float *)dst->data;
cudaStream_t stream = ctx.stream();
GGML_ASSERT(src1->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F16 || dst->type == GGML_TYPE_F32);
GGML_TENSOR_BINARY_OP_LOCALS
const int32_t s0 = ((const int32_t *)(dst->op_params))[0];
const int32_t s1 = ((const int32_t *)(dst->op_params))[1];
const int32_t s2 = ((const int32_t *)(dst->op_params))[2];
const int32_t p0 = ((const int32_t *)(dst->op_params))[3];
const int32_t p1 = ((const int32_t *)(dst->op_params))[4];
const int32_t p2 = ((const int32_t *)(dst->op_params))[5];
const int32_t d0 = ((const int32_t *)(dst->op_params))[6];
const int32_t d1 = ((const int32_t *)(dst->op_params))[7];
const int32_t d2 = ((const int32_t *)(dst->op_params))[8];
const int32_t IC = ((const int32_t *)(dst->op_params))[9];
const int64_t N = ne13 / IC;
const int64_t ID = ne12;
const int64_t IH = ne11;
const int64_t IW = ne10;
const int64_t OC = ne03 / IC;
const int64_t KD = ne02;
const int64_t KH = ne01;
const int64_t KW = ne00;
const int64_t OD = ne3 / N;
const int64_t OH = ne2;
const int64_t OW = ne1;
if(dst->type == GGML_TYPE_F16) {
im2col_3d_cuda_f16(src1_d, (half *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
} else {
im2col_3d_cuda_f32(src1_d, (float *) dst_d, N, IC, ID, IH, IW, OC, KD, KH, KW, OD, OH, OW, s0, s1, s2, p0, p1, p2, d0, d1, d2, stream);
}
}
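// For reference, the spatial output sizes implied by the layout comment above follow
// standard convolution arithmetic (an assumption about the op's semantics; the kernel
// itself simply trusts the shape of dst):
//
//   OW = (IW + 2*p0 - d0*(KW - 1) - 1) / s0 + 1;
//   OH = (IH + 2*p1 - d1*(KH - 1) - 1) / s1 + 1;
//   OD = (ID + 2*p2 - d2*(KD - 1) - 1) / s2 + 1;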

View File

@ -3,3 +3,4 @@
#define CUDA_IM2COL_BLOCK_SIZE 256

void ggml_cuda_op_im2col(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
void ggml_cuda_op_im2col_3d(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -141,9 +141,10 @@ template <ggml_type type, int ncols_dst>
__launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1) __launch_bounds__(calc_nwarps(ncols_dst, get_device_table_id())*ggml_cuda_get_physical_warp_size(), 1)
static __global__ void mul_mat_vec_q( static __global__ void mul_mat_vec_q(
const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst, const void * __restrict__ vx, const void * __restrict__ vy, const int32_t * __restrict__ ids, float * __restrict__ dst,
const int ncols_x, const int nchannels_y, const int stride_row_x, const int stride_col_y, const int stride_col_dst, const uint32_t ncols_x, const uint3 nchannels_y, const uint32_t stride_row_x, const uint32_t stride_col_y,
const int channel_ratio, const int stride_channel_x, const int stride_channel_y, const int stride_channel_dst, const uint32_t stride_col_dst, const uint3 channel_ratio, const uint32_t stride_channel_x,
const int sample_ratio, const int stride_sample_x, const int stride_sample_y, const int stride_sample_dst) { const uint32_t stride_channel_y, const uint32_t stride_channel_dst, const uint3 sample_ratio,
const uint32_t stride_sample_x, const uint32_t stride_sample_y, const uint32_t stride_sample_dst) {
    constexpr int qk = ggml_cuda_type_traits<type>::qk;
    constexpr int qi = ggml_cuda_type_traits<type>::qi;
@ -161,12 +162,12 @@ static __global__ void mul_mat_vec_q(
constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi; constexpr int blocks_per_iter = vdr * nwarps*warp_size / qi;
// The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1. // The MUL_MAT_ID code path with ids != nullptr is only implemented for ncols_dst == 1.
const int channel_dst = blockIdx.y; const uint32_t channel_dst = blockIdx.y;
const int channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : channel_dst / channel_ratio; const uint32_t channel_x = ncols_dst == 1 && ids ? ids[channel_dst] : fastdiv(channel_dst, channel_ratio);
const int channel_y = ncols_dst == 1 && ids ? channel_dst % nchannels_y : channel_dst; const uint32_t channel_y = ncols_dst == 1 && ids ? fastmodulo(channel_dst, nchannels_y) : channel_dst;
const int sample_dst = blockIdx.z; const uint32_t sample_dst = blockIdx.z;
const int sample_x = sample_dst / sample_ratio; const uint32_t sample_x = fastdiv(sample_dst, sample_ratio);
const int sample_y = sample_dst; const uint32_t sample_y = sample_dst;
// partial sum for each thread // partial sum for each thread
float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}}; float tmp[ncols_dst][rows_per_cuda_block] = {{0.0f}};
@ -247,8 +248,9 @@ static void mul_mat_vec_q_switch_ncols_dst(
    GGML_ASSERT(ncols_x % ggml_blck_size(type) == 0);
    GGML_ASSERT(ncols_dst <= MMVQ_MAX_BATCH_SIZE);
const int channel_ratio = nchannels_dst / nchannels_x; const uint3 nchannels_y_fd = ids ? init_fastdiv_values(nchannels_y) : make_uint3(0, 0, 0);
const int sample_ratio = nsamples_dst / nsamples_x; const uint3 channel_ratio_fd = ids ? make_uint3(0, 0, 0) : init_fastdiv_values(nchannels_dst / nchannels_x);
const uint3 sample_ratio_fd = init_fastdiv_values(nsamples_dst / nsamples_x);
const int device = ggml_cuda_get_device(); const int device = ggml_cuda_get_device();
const int warp_size = ggml_cuda_info().devices[device].warp_size; const int warp_size = ggml_cuda_info().devices[device].warp_size;
@ -256,86 +258,70 @@ static void mul_mat_vec_q_switch_ncols_dst(
    GGML_ASSERT(!ids || ncols_dst == 1);
    switch (ncols_dst) {
case 1: case 1: {
{
constexpr int c_ncols_dst = 1; constexpr int c_ncols_dst = 1;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 2: {
case 2:
{
constexpr int c_ncols_dst = 2; constexpr int c_ncols_dst = 2;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 3: {
case 3:
{
constexpr int c_ncols_dst = 3; constexpr int c_ncols_dst = 3;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 4: {
case 4:
{
constexpr int c_ncols_dst = 4; constexpr int c_ncols_dst = 4;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 5: {
case 5:
{
constexpr int c_ncols_dst = 5; constexpr int c_ncols_dst = 5;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 6: {
case 6:
{
constexpr int c_ncols_dst = 6; constexpr int c_ncols_dst = 6;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 7: {
case 7:
{
constexpr int c_ncols_dst = 7; constexpr int c_ncols_dst = 7;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
} case 8: {
case 8:
{
constexpr int c_ncols_dst = 8; constexpr int c_ncols_dst = 8;
std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id); std::pair<dim3, dim3> dims = calc_launch_params(c_ncols_dst, nrows_x, nchannels_dst, nsamples_dst, warp_size, table_id);
mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>> mul_mat_vec_q<type, c_ncols_dst><<<dims.first, dims.second, 0, stream>>>
(vx, vy, ids, dst, ncols_x, nchannels_y, stride_row_x, stride_col_y, stride_col_dst, (vx, vy, ids, dst, ncols_x, nchannels_y_fd, stride_row_x, stride_col_y, stride_col_dst,
channel_ratio, stride_channel_x, stride_channel_y, stride_channel_dst, channel_ratio_fd, stride_channel_x, stride_channel_y, stride_channel_dst,
sample_ratio, stride_sample_x, stride_sample_y, stride_sample_dst); sample_ratio_fd, stride_sample_x, stride_sample_y, stride_sample_dst);
break; } break;
}
        default:
            GGML_ABORT("fatal error");
            break;

View File

@ -105,29 +105,29 @@ static __global__ void group_norm_f32(const float * x, float * dst, const int gr
}

template <int block_size, bool do_multiply = false, bool do_add = false>
static __global__ void rms_norm_f32(const float * x, float * dst, static __global__ void rms_norm_f32(const float * x,
float * dst,
        const int ncols,
        const int64_t stride_row,
        const int64_t stride_channel,
        const int64_t stride_sample,
        const float eps,
        const float * mul = nullptr,
        const int64_t mul_stride_row = 0,
        const int64_t mul_stride_channel = 0,
        const int64_t mul_stride_sample = 0,
const int mul_ncols = 0, const uint3 mul_ncols_packed = make_uint3(0, 0, 0),
const int mul_nrows = 0, const uint3 mul_nrows_packed = make_uint3(0, 0, 0),
const int mul_nchannels = 0, const uint3 mul_nchannels_packed = make_uint3(0, 0, 0),
const int mul_nsamples = 0, const uint3 mul_nsamples_packed = make_uint3(0, 0, 0),
        const float * add = nullptr,
        const int64_t add_stride_row = 0,
        const int64_t add_stride_channel = 0,
        const int64_t add_stride_sample = 0,
const int add_ncols = 0, const uint3 add_ncols_packed = make_uint3(0, 0, 0),
const int add_nrows = 0, const uint3 add_nrows_packed = make_uint3(0, 0, 0),
const int add_nchannels = 0, const uint3 add_nchannels_packed = make_uint3(0, 0, 0),
const int add_nsamples = 0) { const uint3 add_nsamples_packed = make_uint3(0, 0, 0)) {
const int nrows = gridDim.x; const int nrows = gridDim.x;
const int nchannels = gridDim.y; const int nchannels = gridDim.y;
@ -142,16 +142,16 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
dst += ((sample*nchannels + channel)*nrows + row)*ncols; dst += ((sample*nchannels + channel)*nrows + row)*ncols;
if constexpr (do_multiply) { if constexpr (do_multiply) {
const int mul_row = row % mul_nrows; const uint32_t mul_row = fastmodulo(row, mul_nrows_packed);
const int mul_channel = channel % mul_nchannels; const uint32_t mul_channel = fastmodulo(channel, mul_nchannels_packed);
const int mul_sample = sample % mul_nsamples; const uint32_t mul_sample = fastmodulo(sample, mul_nsamples_packed);
mul += mul_sample*mul_stride_sample + mul_channel*mul_stride_channel + mul_row*mul_stride_row; mul += mul_sample * mul_stride_sample + mul_channel * mul_stride_channel + mul_row * mul_stride_row;
} }
if constexpr (do_add) { if constexpr (do_add) {
const int add_row = row % add_nrows; const int add_row = fastmodulo(row, add_nrows_packed);
const int add_channel = channel % add_nchannels; const int add_channel = fastmodulo(channel, add_nchannels_packed);
const int add_sample = sample % add_nsamples; const int add_sample = fastmodulo(sample, add_nsamples_packed);
add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row; add += add_sample * add_stride_sample + add_channel * add_stride_channel + add_row * add_stride_row;
} }
@ -165,15 +165,18 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
// sum up partial sums // sum up partial sums
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
if constexpr (block_size > WARP_SIZE) { if constexpr (block_size > WARP_SIZE) {
static_assert(block_size == 1024, "unexpected block_size"); static_assert((block_size <= 1024) && (block_size % 32 == 0), "unexpected block_size");
__shared__ float s_sum[32]; __shared__ float s_sum[32];
const int warp_id = threadIdx.x / WARP_SIZE; const int warp_id = tid / WARP_SIZE;
const int lane_id = threadIdx.x % WARP_SIZE; const int lane_id = tid % WARP_SIZE;
if (lane_id == 0) { if (lane_id == 0) {
s_sum[warp_id] = tmp; s_sum[warp_id] = tmp;
} }
__syncthreads(); __syncthreads();
tmp = s_sum[lane_id]; tmp = 0.0f;
if (lane_id < (block_size / WARP_SIZE)) {
tmp = s_sum[lane_id];
}
tmp = warp_reduce_sum(tmp); tmp = warp_reduce_sum(tmp);
} }
@ -182,12 +185,12 @@ static __global__ void rms_norm_f32(const float * x, float * dst,
for (int col = tid; col < ncols; col += block_size) { for (int col = tid; col < ncols; col += block_size) {
if constexpr (do_multiply && do_add) { if constexpr (do_multiply && do_add) {
const int mul_col = col % mul_ncols; const int mul_col = fastmodulo(col, mul_ncols_packed);
const int add_col = col % add_ncols; const int add_col = fastmodulo(col, add_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col] + add[add_col]; dst[col] = scale * x[col] * mul[mul_col] + add[add_col];
} else if constexpr (do_multiply) { } else if constexpr (do_multiply) {
const int mul_col = col % mul_ncols; const int mul_col = fastmodulo(col, mul_ncols_packed);
dst[col] = scale * x[col] * mul[mul_col]; dst[col] = scale * x[col] * mul[mul_col];
} else { } else {
dst[col] = scale * x[col]; dst[col] = scale * x[col];
} }
@ -354,77 +357,86 @@ static void rms_norm_f32_cuda(
const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) { const int64_t stride_row, const int64_t stride_channel, const int64_t stride_sample, const float eps, cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples); const dim3 blocks_num(nrows, nchannels, nsamples);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(256, 1, 1);
rms_norm_f32<WARP_SIZE, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); rms_norm_f32<256, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} else { } else {
const dim3 block_dims(1024, 1, 1); const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps); rms_norm_f32<1024, false><<<blocks_num, block_dims, 0, stream>>>(x, dst, ncols, stride_row, stride_channel, stride_sample, eps);
} }
} }
static void rms_norm_mul_f32_cuda(const float * x,
        const float * mul,
        const float * add,
        float * dst,
        const int ncols,
        const int nrows,
        const int nchannels,
        const int nsamples,
        const int64_t stride_row,
        const int64_t stride_channel,
        const int64_t stride_sample,
        const int64_t mul_stride_row,
        const int64_t mul_stride_channel,
        const int64_t mul_stride_sample,
const int mul_ncols, const uint32_t mul_ncols,
const int mul_nrows, const uint32_t mul_nrows,
const int mul_nchannels, const uint32_t mul_nchannels,
const int mul_nsamples, const uint32_t mul_nsamples,
const int64_t add_stride_row, const int64_t add_stride_row,
const int64_t add_stride_channel, const int64_t add_stride_channel,
const int64_t add_stride_sample, const int64_t add_stride_sample,
const int add_ncols, const uint32_t add_ncols,
const int add_nrows, const uint32_t add_nrows,
const int add_nchannels, const uint32_t add_nchannels,
const int add_nsamples, const uint32_t add_nsamples,
const float eps, const float eps,
cudaStream_t stream) { cudaStream_t stream) {
const dim3 blocks_num(nrows, nchannels, nsamples); const dim3 blocks_num(nrows, nchannels, nsamples);
if (mul == nullptr) { if (mul == nullptr) {
rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream); rms_norm_f32_cuda(x, dst, ncols, nrows, nchannels, nsamples, stride_row, stride_channel, stride_sample, eps, stream);
return; return;
} }
if (add == nullptr) { if (add == nullptr) {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(256, 1, 1);
rms_norm_f32<WARP_SIZE, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, rms_norm_f32<256, true><<<blocks_num, block_dims, 0, stream>>>(
ncols, stride_row, stride_channel, stride_sample, eps, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} else { } else {
const dim3 block_dims(1024, 1, 1); const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, rms_norm_f32<1024, true><<<blocks_num, block_dims, 0, stream>>>(
ncols, stride_row, stride_channel, stride_sample, eps, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed);
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples);
} }
} else { } else {
const uint3 mul_ncols_packed = init_fastdiv_values(mul_ncols);
const uint3 mul_nrows_packed = init_fastdiv_values(mul_nrows);
const uint3 mul_nchannels_packed = init_fastdiv_values(mul_nchannels);
const uint3 mul_nsamples_packed = init_fastdiv_values(mul_nsamples);
const uint3 add_ncols_packed = init_fastdiv_values(add_ncols);
const uint3 add_nrows_packed = init_fastdiv_values(add_nrows);
const uint3 add_nchannels_packed = init_fastdiv_values(add_nchannels);
const uint3 add_nsamples_packed = init_fastdiv_values(add_nsamples);
if (ncols < 1024) { if (ncols < 1024) {
const dim3 block_dims(WARP_SIZE, 1, 1); const dim3 block_dims(256, 1, 1);
rms_norm_f32<WARP_SIZE, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, rms_norm_f32<256, true, true><<<blocks_num, block_dims, 0, stream>>>(
ncols, stride_row, stride_channel, stride_sample, eps, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add, add_stride_row, add_stride_channel, add_stride_sample, add_nchannels_packed, add_nsamples_packed);
add_ncols, add_nrows, add_nchannels, add_nsamples);
} else { } else {
const dim3 block_dims(1024, 1, 1); const dim3 block_dims(1024, 1, 1);
rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(x, dst, rms_norm_f32<1024, true, true><<<blocks_num, block_dims, 0, stream>>>(
ncols, stride_row, stride_channel, stride_sample, eps, x, dst, ncols, stride_row, stride_channel, stride_sample, eps, mul, mul_stride_row, mul_stride_channel,
mul, mul_stride_row, mul_stride_channel, mul_stride_sample, mul_stride_sample, mul_ncols_packed, mul_nrows_packed, mul_nchannels_packed, mul_nsamples_packed, add,
mul_ncols, mul_nrows, mul_nchannels, mul_nsamples, add_stride_row, add_stride_channel, add_stride_sample, add_ncols_packed, add_nrows_packed,
add, add_stride_row, add_stride_channel, add_stride_sample, add_nchannels_packed, add_nsamples_packed);
add_ncols, add_nrows, add_nchannels, add_nsamples);
} }
} }
} }
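// Note on the reduction above: rms_norm_f32 now uses 256-thread blocks for small rows
// (instead of a single warp), so the shared-memory step has to combine
// block_size/WARP_SIZE partial sums rather than assuming exactly 32. Condensed sketch
// of the pattern used in the kernel (comments only, not additional code):
//
//   tmp = warp_reduce_sum(tmp);                              // intra-warp sum
//   if (lane_id == 0) s_sum[warp_id] = tmp;                  // one partial per warp
//   __syncthreads();
//   tmp = lane_id < block_size/WARP_SIZE ? s_sum[lane_id] : 0.0f;
//   tmp = warp_reduce_sum(tmp);                              // final block sum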

View File

@ -1,36 +1,50 @@
#include "pad.cuh" #include "pad.cuh"
static __global__ void pad_f32(const float * x, float * dst, const int ne0, const int ne00, const int ne01, const int ne02, const int ne03) { static __global__ void pad_f32(const float * src, float * dst,
// blockIdx.z: idx of ne2*ne3, aka ne02*ne03 const int lp0, const int rp0, const int lp1, const int rp1,
// blockIdx.y: idx of ne1 const int lp2, const int rp2, const int lp3, const int rp3,
// blockIDx.x: idx of ne0 / BLOCK_SIZE const int ne0, const int ne1, const int ne2, const int ne3) {
int nidx = threadIdx.x + blockIdx.x * blockDim.x; // blockIdx.z: i3*ne2+i2
if (nidx >= ne0) { // blockIdx.y: i1
// blockIDx.x: i0 / CUDA_PAD_BLOCK_SIZE
// gridDim.y: ne1
int i0 = threadIdx.x + blockIdx.x * blockDim.x;
int i1 = blockIdx.y;
int i2 = blockIdx.z % ne2;
int i3 = blockIdx.z / ne2;
if (i0 >= ne0 || i1 >= ne1 || i2 >= ne2 || i3 >= ne3) {
return; return;
} }
// operation // operation
int offset_dst = const int64_t dst_idx = i3*(ne0*ne1*ne2) + i2*(ne0*ne1) + i1*ne0 + i0;
nidx + if ((i0 >= lp0 && i0 < ne0 - rp0) &&
blockIdx.y * ne0 + (i1 >= lp1 && i1 < ne1 - rp1) &&
blockIdx.z * ne0 * gridDim.y; (i2 >= lp2 && i2 < ne2 - rp2) &&
if (nidx < ne00 && blockIdx.y < (unsigned)ne01 && blockIdx.z < (unsigned)(ne02*ne03)) { (i3 >= lp3 && i3 < ne3 - rp3)) {
int offset_src = const int64_t i00 = i0 - lp0;
nidx + const int64_t i01 = i1 - lp1;
blockIdx.y * ne00 + const int64_t i02 = i2 - lp2;
blockIdx.z * ne00 * ne01; const int64_t i03 = i3 - lp3;
dst[offset_dst] = x[offset_src]; const int64_t ne02 = ne2 - lp2 - rp2;
const int64_t ne01 = ne1 - lp1 - rp1;
const int64_t ne00 = ne0 - lp0 - rp0;
const int64_t src_idx = i03*(ne00*ne01*ne02) + i02*(ne00*ne01) + i01*ne00 + i00;
dst[dst_idx] = src[src_idx];
} else { } else {
dst[offset_dst] = 0.0f; dst[dst_idx] = 0.0f;
} }
} }
static void pad_f32_cuda(const float * x, float * dst, static void pad_f32_cuda(const float * src, float * dst,
const int ne00, const int ne01, const int ne02, const int ne03, const int lp0, const int rp0, const int lp1, const int rp1,
const int lp2, const int rp2, const int lp3, const int rp3,
const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) { const int ne0, const int ne1, const int ne2, const int ne3, cudaStream_t stream) {
int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE; int num_blocks = (ne0 + CUDA_PAD_BLOCK_SIZE - 1) / CUDA_PAD_BLOCK_SIZE;
dim3 gridDim(num_blocks, ne1, ne2*ne3); dim3 gridDim(num_blocks, ne1, ne2*ne3);
pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(x, dst, ne0, ne00, ne01, ne02, ne03); pad_f32<<<gridDim, CUDA_PAD_BLOCK_SIZE, 0, stream>>>(src, dst, lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3, ne0, ne1, ne2, ne3);
} }
void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
@ -41,9 +55,18 @@ void ggml_cuda_op_pad(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
    GGML_ASSERT(src0->type == GGML_TYPE_F32);
    GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(src0->ne[3] == 1 && dst->ne[3] == 1); // just 3D tensors GGML_ASSERT(ggml_is_contiguous(src0));
const int32_t lp0 = ((const int32_t*)(dst->op_params))[0];
const int32_t rp0 = ((const int32_t*)(dst->op_params))[1];
const int32_t lp1 = ((const int32_t*)(dst->op_params))[2];
const int32_t rp1 = ((const int32_t*)(dst->op_params))[3];
const int32_t lp2 = ((const int32_t*)(dst->op_params))[4];
const int32_t rp2 = ((const int32_t*)(dst->op_params))[5];
const int32_t lp3 = ((const int32_t*)(dst->op_params))[6];
const int32_t rp3 = ((const int32_t*)(dst->op_params))[7];
pad_f32_cuda(src0_d, dst_d, pad_f32_cuda(src0_d, dst_d,
src0->ne[0], src0->ne[1], src0->ne[2], src0->ne[3], lp0, rp0, lp1, rp1, lp2, rp2, lp3, rp3,
dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream); dst->ne[0], dst->ne[1], dst->ne[2], dst->ne[3], stream);
} }
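// Note on the rewritten kernel above: padding is now per-dimension and asymmetric
// (lp*/rp* read from op_params), and each destination index maps back to the source
// by subtracting the left pad. One-dimensional sketch of the mapping (illustrative only):
//
//   ne0 = ne00 + lp0 + rp0;                    // padded size along dim 0
//   dst[i0] = (i0 >= lp0 && i0 < ne0 - rp0)    // inside the copied region?
//           ? src[i0 - lp0]
//           : 0.0f;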

View File

@ -1,26 +1,27 @@
#include "quantize.cuh" #include "quantize.cuh"
#include <cstdint> #include <cstdint>
__launch_bounds__(CUDA_QUANTIZE_BLOCK_SIZE, 1)
static __global__ void quantize_q8_1(
        const float * __restrict__ x, void * __restrict__ vy,
        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) { const int64_t ne0, const uint32_t ne1, const uint3 ne2) {
    const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;

    if (i0 >= ne0) {
        return;
    }
const int64_t i3 = fastdiv(blockIdx.z, ne2);
const int64_t i2 = blockIdx.z - i3*ne2.z;
const int64_t i1 = blockIdx.y; const int64_t i1 = blockIdx.y;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t & i00 = i0; const int64_t & i00 = i0;
const int64_t & i01 = i1; const int64_t & i01 = i1;
const int64_t & i02 = i2; const int64_t & i02 = i2;
const int64_t & i03 = i3; const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0; const int64_t i_cont = ((i3*ne2.z + i2) * ne1 + i1) * ne0 + i0;
block_q8_1 * y = (block_q8_1 *) vy; block_q8_1 * y = (block_q8_1 *) vy;
@ -31,10 +32,10 @@ static __global__ void quantize_q8_1(
float amax = fabsf(xi); float amax = fabsf(xi);
float sum = xi; float sum = xi;
amax = warp_reduce_max(amax); amax = warp_reduce_max<QK8_1>(amax);
sum = warp_reduce_sum(sum); sum = warp_reduce_sum<QK8_1>(sum);
const float d = amax / 127; const float d = amax / 127.0f;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q; y[ib].qs[iqs] = q;
@ -43,8 +44,7 @@ static __global__ void quantize_q8_1(
return; return;
} }
reinterpret_cast<half&>(y[ib].ds.x) = d; y[ib].ds = make_half2(d, sum);
reinterpret_cast<half&>(y[ib].ds.y) = sum;
} }
template <mmq_q8_1_ds_layout ds_layout> template <mmq_q8_1_ds_layout ds_layout>
@ -152,10 +152,12 @@ void quantize_row_q8_1_cuda(
GGML_ASSERT(!ids); GGML_ASSERT(!ids);
GGML_ASSERT(ne0 % QK8_1 == 0); GGML_ASSERT(ne0 % QK8_1 == 0);
const uint3 ne2_fastdiv = init_fastdiv_values(ne2);
const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE; const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
const dim3 num_blocks(block_num_x, ne1, ne2*ne3); const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1); const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2); quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2_fastdiv);
GGML_UNUSED(type_src0); GGML_UNUSED(type_src0);
} }
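// For reference, a scalar sketch of what each q8_1 block stores after the change above
// (assuming QK8_1 == 32; illustrative only):
//
//   float amax = 0.0f, sum = 0.0f;
//   for (int j = 0; j < QK8_1; ++j) { amax = fmaxf(amax, fabsf(x[j])); sum += x[j]; }
//   const float d = amax / 127.0f;
//   for (int j = 0; j < QK8_1; ++j) { qs[j] = amax == 0.0f ? 0 : (int8_t) roundf(x[j] / d); }
//   ds = make_half2(d, sum);                   // per-block scale and sum, stored as half2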

View File

@ -1,18 +1,19 @@
#include "scale.cuh" #include "scale.cuh"
static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int k) { #define MAX_GRIDDIM_X 0x7FFFFFFF
const int i = blockDim.x*blockIdx.x + threadIdx.x;
if (i >= k) { static __global__ void scale_f32(const float * x, float * dst, const float scale, const float bias, const int64_t nelements) {
return; int64_t tid = (int64_t)blockIdx.x * (int64_t)blockDim.x + (int64_t)threadIdx.x;
int64_t stride = (int64_t)blockDim.x * (int64_t)gridDim.x;
for (int64_t i = tid; i < nelements; i += stride) {
dst[i] = scale * x[i] + bias;
} }
dst[i] = scale * x[i] + bias;
} }
static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int k, cudaStream_t stream) { static void scale_f32_cuda(const float * x, float * dst, const float scale, const float bias, const int64_t nelements, cudaStream_t stream) {
const int num_blocks = (k + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE; const int64_t num_blocks = (nelements + CUDA_SCALE_BLOCK_SIZE - 1) / CUDA_SCALE_BLOCK_SIZE;
scale_f32<<<num_blocks, CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, k); scale_f32<<<MIN(MAX_GRIDDIM_X, num_blocks), CUDA_SCALE_BLOCK_SIZE, 0, stream>>>(x, dst, scale, bias, nelements);
} }
void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { void ggml_cuda_op_scale(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {

View File

@ -407,6 +407,7 @@ enum ggml_metal_kernel_type {
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
    GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
@ -523,13 +524,6 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64,
GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64,
@ -1446,6 +1440,7 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10, mul_mm_id_map0_f16_ne20_10, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
@ -1562,13 +1557,6 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK192_HV128, flash_attn_ext_q8_0_hk192_hv128, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_H256, flash_attn_ext_q8_0_h256, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_Q8_0_HK576_HV512, flash_attn_ext_q8_0_hk576_hv512, has_simdgroup_mm);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40, flash_attn_ext_vec_f16_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40, flash_attn_ext_vec_bf16_h40, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40, flash_attn_ext_vec_q4_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40, flash_attn_ext_vec_q4_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40, flash_attn_ext_vec_q5_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40, flash_attn_ext_vec_q5_1_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40, flash_attn_ext_vec_q8_0_h40, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H64, flash_attn_ext_vec_f16_h64, has_simdgroup_reduction);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H64, flash_attn_ext_vec_bf16_h64, has_simdgroup_reduction && use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H64, flash_attn_ext_vec_q4_0_h64, has_simdgroup_reduction);
@ -1900,7 +1888,10 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST; return op->src[0]->type == GGML_TYPE_F32 && op->op_params[0] == GGML_SCALE_MODE_NEAREST;
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD: case GGML_OP_PAD:
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_PAD_REFLECT_1D: case GGML_OP_PAD_REFLECT_1D:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_ARGSORT: case GGML_OP_ARGSORT:
@ -1909,9 +1900,15 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_OP_ARANGE: case GGML_OP_ARANGE:
return true; return true;
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
if (op->src[0]->ne[0] == 32) { // for new head sizes, add checks here
// head size == 32 (e.g. bert-bge-small) if (op->src[0]->ne[0] != 40 &&
// TODO: not sure if it is worth adding kernels for this size op->src[0]->ne[0] != 64 &&
op->src[0]->ne[0] != 80 &&
op->src[0]->ne[0] != 96 &&
op->src[0]->ne[0] != 112 &&
op->src[0]->ne[0] != 128 &&
op->src[0]->ne[0] != 192 &&
op->src[0]->ne[0] != 256) {
return false; return false;
} }
if (op->src[0]->ne[0] == 576) { if (op->src[0]->ne[0] == 576) {
@ -3984,6 +3981,7 @@ static int ggml_metal_encode_node(
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break; case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break; case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break; case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
case 10: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_10].pipeline; break;
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break; case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20); default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
} }
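The Metal backend dispatches to explicitly instantiated kernel variants, so the supports_op whitelist (the new explicit head-size list above) and the pipeline selection (the ne20 switch that now also handles 10) have to agree on the same set of sizes, with anything else rejected or aborted. A small C++ sketch of that pattern, using hypothetical helper names:

    #include <cstdio>
    #include <cstdlib>
    #include <string>

    // Sketch: keep the op-support check and the specialization lookup in sync.
    // The value set here mirrors the MUL_MM_ID_MAP0 switch above; the full set in
    // the backend may differ.
    static bool supported_ne20(int ne20) {
        switch (ne20) {
            case 2: case 4: case 6: case 8: case 10: case 16:
                return true;
            default:
                return false;
        }
    }

    static std::string pick_pipeline(int ne20) {   // stand-in for the kernels[...] lookup
        if (!supported_ne20(ne20)) {
            std::fprintf(stderr, "missing specialization for ne20 = %d\n", ne20);
            std::abort();
        }
        return "kernel_mul_mm_id_map0_f16_ne20_" + std::to_string(ne20);
    }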
@ -5138,10 +5136,8 @@ static int ggml_metal_encode_node(
bool use_vec_kernel = false; bool use_vec_kernel = false;
// TODO: add vec kernels for (ne00%64 == 0) and maybe also for (ne00%32 == 0) // use non-vec kernel if the batch size is large or if the vec-kernel is not supported for this head size
// for now avoiding mainly to keep the number of templates/kernels a bit lower if (ne01 >= 20 || (ne00 == 40 || ne00 == 80 || ne00 == 112)) {
// these are now trivial to add after: https://github.com/ggml-org/llama.cpp/pull/12612
if (ne01 >= 20 || (ne00%128 != 0 && ne00 != 64 && ne00 != 96 && ne00 != 192 && ne00 != 576)) {
switch (src1->type) { switch (src1->type) {
case GGML_TYPE_F16: case GGML_TYPE_F16:
{ {
@ -5329,24 +5325,6 @@ static int ggml_metal_encode_node(
use_vec_kernel = true; use_vec_kernel = true;
switch (ne00) { switch (ne00) {
case 40:
{
switch (src1->type) {
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_F16_H40].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_BF16_H40].pipeline; break;
case GGML_TYPE_Q4_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_0_H40].pipeline; break;
case GGML_TYPE_Q4_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q4_1_H40].pipeline; break;
case GGML_TYPE_Q5_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_0_H40].pipeline; break;
case GGML_TYPE_Q5_1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q5_1_H40].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_FLASH_ATTN_EXT_VEC_Q8_0_H40].pipeline; break;
default:
{
GGML_LOG_ERROR("unsupported type: %d\n", src1->type);
GGML_LOG_ERROR("add template specialization for this type\n");
GGML_ABORT("add template specialization for this type");
}
}
} break;
case 64: case 64:
{ {
switch (src1->type) { switch (src1->type) {

View File

@ -4803,6 +4803,9 @@ kernel void kernel_flash_attn_ext_vec(
ushort3 ntg[[threads_per_threadgroup]], ushort3 ntg[[threads_per_threadgroup]],
ushort tiisg[[thread_index_in_simdgroup]], ushort tiisg[[thread_index_in_simdgroup]],
ushort sgitg[[simdgroup_index_in_threadgroup]]) { ushort sgitg[[simdgroup_index_in_threadgroup]]) {
static_assert(DK % 32 == 0, "DK must be divisible by 32");
static_assert(DV % 32 == 0, "DV must be divisible by 32");
const short nsg = ntg.y; // number of simdgroups const short nsg = ntg.y; // number of simdgroups
const short iwg = tgpig[2]%nwg; const short iwg = tgpig[2]%nwg;
@ -5160,16 +5163,6 @@ kernel void kernel_flash_attn_ext_vec(
typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t; typedef decltype(kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 128, 128, 4>) flash_attn_ext_vec_t;
template [[host_name("kernel_flash_attn_ext_vec_f16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 40, 40, 8>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 40, 40, 8>;
#endif
template [[host_name("kernel_flash_attn_ext_vec_q4_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_0, 8, dequantize_q4_0_t4, block_q4_0, 8, dequantize_q4_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q4_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q4_1, 8, dequantize_q4_1_t4, block_q4_1, 8, dequantize_q4_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_0, 8, dequantize_q5_0_t4, block_q5_0, 8, dequantize_q5_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q5_1_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q5_1, 8, dequantize_q5_1_t4, block_q5_1, 8, dequantize_q5_1_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_q8_0_h40")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, block_q8_0, 8, dequantize_q8_0_t4, block_q8_0, 8, dequantize_q8_0_t4, 40, 40, 8>;
template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>; template [[host_name("kernel_flash_attn_ext_vec_f16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, half4, 1, dequantize_f16_t4, half4, 1, dequantize_f16_t4, 64, 64, 8>;
#if defined(GGML_METAL_USE_BF16) #if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>; template [[host_name("kernel_flash_attn_ext_vec_bf16_h64")]] kernel flash_attn_ext_vec_t kernel_flash_attn_ext_vec<FA_TYPES, bfloat4, 1, dequantize_bf16_t4, bfloat4, 1, dequantize_bf16_t4, 64, 64, 8>;
@ -7625,6 +7618,7 @@ template [[host_name("kernel_mul_mm_id_map0_f16_ne20_2" )]] kernel kernel_mul_mm
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_4" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<4>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_6" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<6>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_8" )]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<8>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_10")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<10>;
template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>; template [[host_name("kernel_mul_mm_id_map0_f16_ne20_16")]] kernel kernel_mul_mm_id_map0_t kernel_mul_mm_id_map0<16>;
template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)> template<typename T, typename T4x4, typename simdgroup_T8x8, typename block_q, short nl, void (*dequantize_func)(device const block_q *, short, thread T4x4 &)>

View File

@ -1339,7 +1339,7 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) { if (!kernel_src_f16.empty() && !kernel_src_f32.empty() && !kernel_src_f32_f16.empty()) {
const struct { int dk; int dv; int bm; int bn; } fa_dims[] = { const struct { int dk; int dv; int bm; int bn; } fa_dims[] = {
{ 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32}, { 40, 40, 32, 32}, { 64, 64, 64, 64}, { 80, 80, 64, 32}, { 96, 96, 64, 32},
{112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16}, {112, 112, 32, 32}, {128, 128, 32, 32}, {192, 128, 16, 16},
{192, 192, 16, 16}, {256, 256, 16, 16}, {192, 192, 16, 16}, {256, 256, 16, 16},
}; };
@ -2701,7 +2701,9 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; // Assuming F32 for now, can be expanded
case GGML_OP_PAD: case GGML_OP_PAD:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 && return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32 &&
op->src[0]->ne[3] == 1 && op->ne[3] == 1; op->src[0]->ne[3] == 1 && op->ne[3] == 1 &&
(ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_UPSCALE: case GGML_OP_UPSCALE:
return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32; return op->src[0]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D: case GGML_OP_CONV_2D:
@ -2776,10 +2778,6 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]); return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT: case GGML_OP_FLASH_ATTN_EXT:
{ {
if (op->src[4]) {
return false;
}
const ggml_tensor * q = op->src[0]; const ggml_tensor * q = op->src[0];
const ggml_tensor * k = op->src[1]; const ggml_tensor * k = op->src[1];
const ggml_tensor * v = op->src[2]; const ggml_tensor * v = op->src[2];
@ -2788,7 +2786,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
const int dv = v->ne[0]; const int dv = v->ne[0];
const struct { int dk; int dv; } supported_dims[] = { const struct { int dk; int dv; } supported_dims[] = {
{ 64, 64}, { 80, 80}, { 96, 96}, { 40, 40}, { 64, 64}, { 80, 80}, { 96, 96},
{112, 112}, {128, 128}, {192, 128}, {112, 112}, {128, 128}, {192, 128},
{192, 192}, {256, 256}, {192, 192}, {256, 256},
}; };
@ -5765,6 +5763,7 @@ static void ggml_cl_timestep_embedding(ggml_backend_t backend, const ggml_tensor
static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) { static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, const ggml_tensor * k, ggml_tensor * dst) {
const ggml_tensor * v = dst->src[2]; const ggml_tensor * v = dst->src[2];
const ggml_tensor * mask = dst->src[3]; const ggml_tensor * mask = dst->src[3];
const ggml_tensor * sinks = dst->src[4];
GGML_ASSERT(q->extra); GGML_ASSERT(q->extra);
GGML_ASSERT(k->extra); GGML_ASSERT(k->extra);
GGML_ASSERT(v->extra); GGML_ASSERT(v->extra);
@ -5772,6 +5771,9 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
if (mask) { if (mask) {
GGML_ASSERT(mask->extra); GGML_ASSERT(mask->extra);
} }
if (sinks) {
GGML_ASSERT(sinks->extra);
}
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context; ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
@ -5813,6 +5815,7 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra; ggml_tensor_extra_cl * extra_v = (ggml_tensor_extra_cl *)v->extra;
ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra; ggml_tensor_extra_cl * extra_o = (ggml_tensor_extra_cl *)dst->extra;
ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL; ggml_tensor_extra_cl * extra_mask = mask ? (ggml_tensor_extra_cl *)mask->extra : NULL;
ggml_tensor_extra_cl * extra_sinks = sinks ? (ggml_tensor_extra_cl *)sinks->extra : NULL;
cl_ulong offset_q = extra_q->offset + q->view_offs; cl_ulong offset_q = extra_q->offset + q->view_offs;
cl_ulong offset_k = extra_k->offset + k->view_offs; cl_ulong offset_k = extra_k->offset + k->view_offs;
@ -5820,6 +5823,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
cl_ulong offset_o = extra_o->offset + dst->view_offs; cl_ulong offset_o = extra_o->offset + dst->view_offs;
cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL; cl_mem mask_buffer = extra_mask ? extra_mask->data_device : NULL;
cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0; cl_ulong offset_mask = extra_mask ? extra_mask->offset + mask->view_offs : 0;
cl_mem sinks_buffer = extra_sinks ? extra_sinks->data_device : NULL;
cl_ulong offset_sinks = extra_sinks ? extra_sinks->offset + sinks->view_offs : 0;
const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3]; const cl_ulong q_nb1 = q->nb[1], q_nb2 = q->nb[2], q_nb3 = q->nb[3];
const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3]; const cl_ulong k_nb1 = k->nb[1], k_nb2 = k->nb[2], k_nb3 = k->nb[3];
@ -5874,6 +5879,8 @@ static void ggml_cl_flash_attn(ggml_backend_t backend, const ggml_tensor * q, co
CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3)); CL_CHECK(clSetKernelArg(kernel, 35, sizeof(cl_ulong), &mask_nb3));
CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2)); CL_CHECK(clSetKernelArg(kernel, 36, sizeof(int), &mask_ne2));
CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3)); CL_CHECK(clSetKernelArg(kernel, 37, sizeof(int), &mask_ne3));
CL_CHECK(clSetKernelArg(kernel, 38, sizeof(cl_mem), &sinks_buffer));
CL_CHECK(clSetKernelArg(kernel, 39, sizeof(cl_ulong), &offset_sinks));
if (n_q == 1) { if (n_q == 1) {
const size_t wg_size = 64; const size_t wg_size = 64;
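ggml_cl_flash_attn now forwards the optional attention-sinks tensor as two extra kernel arguments (a possibly-NULL buffer plus its offset, args 38 and 39 above), matching the two new parameters added to the .cl kernels below. A hedged host-side C++ sketch of binding such an optional buffer, assuming a kernel whose final two arguments are a global pointer and a ulong offset:

    #include <CL/cl.h>

    // Sketch: bind an optional buffer to a kernel. When the tensor is absent, a NULL
    // cl_mem and a zero offset are passed and the kernel tests the pointer for NULL.
    static cl_int set_optional_buffer_args(cl_kernel kernel, cl_uint first_arg,
                                           cl_mem buffer /* may be NULL */, cl_ulong offset) {
        cl_int err = clSetKernelArg(kernel, first_arg, sizeof(cl_mem), &buffer);
        if (err != CL_SUCCESS) {
            return err;
        }
        return clSetKernelArg(kernel, first_arg + 1, sizeof(cl_ulong), &offset);
    }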

View File

@ -49,7 +49,9 @@ __kernel void flash_attn_f16(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0); const int block_q_idx = get_group_id(0);
@ -171,6 +173,20 @@ __kernel void flash_attn_f16(
} }
if (my_query_row < n_q) { if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) { if (l_i > 0.0f) {
@ -214,7 +230,9 @@ __kernel void flash_attn_f16_q1(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1); const int head_batch_idx = get_global_id(1);
@ -247,7 +265,12 @@ __kernel void flash_attn_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY; const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@ -320,7 +343,11 @@ __kernel void flash_attn_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0]; ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) { if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final; const ACC_TYPE l_inv = 1.0f / l_final;

View File

@ -49,7 +49,9 @@ __kernel void flash_attn_f32(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0); const int block_q_idx = get_group_id(0);
@ -171,6 +173,20 @@ __kernel void flash_attn_f32(
} }
if (my_query_row < n_q) { if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) { if (l_i > 0.0f) {
@ -214,7 +230,9 @@ __kernel void flash_attn_f32_q1(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1); const int head_batch_idx = get_global_id(1);
@ -247,7 +265,12 @@ __kernel void flash_attn_f32_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY; const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset); const global DATA_TYPE4* k_ptr = (const global DATA_TYPE4*)(k_base + k_row_offset);
@ -320,7 +343,11 @@ __kernel void flash_attn_f32_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset); global DATA_TYPE4 *o_row = (global DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0]; ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) { if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final; const ACC_TYPE l_inv = 1.0f / l_final;

View File

@ -52,7 +52,9 @@ __kernel void flash_attn_f32_f16(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int block_q_idx = get_group_id(0); const int block_q_idx = get_group_id(0);
@ -174,6 +176,20 @@ __kernel void flash_attn_f32_f16(
} }
if (my_query_row < n_q) { if (my_query_row < n_q) {
if (sinks_void != NULL) {
const global ACC_TYPE* sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
const ACC_TYPE m_sink = sinks_ptr[head_idx];
const ACC_TYPE m_final = max(m_i, m_sink);
const ACC_TYPE scale_o = exp(m_i - m_final);
#pragma unroll
for (int i = 0; i < DV_VEC; ++i) {
o_acc[i] *= scale_o;
}
l_i = l_i * exp(m_i - m_final) + exp(m_sink - m_final);
}
const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + my_query_row * o_nb2 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
if (l_i > 0.0f) { if (l_i > 0.0f) {
@ -217,7 +233,9 @@ __kernel void flash_attn_f32_f16_q1(
const ulong mask_nb2, const ulong mask_nb2,
const ulong mask_nb3, const ulong mask_nb3,
const int mask_ne2, const int mask_ne2,
const int mask_ne3 const int mask_ne3,
const global void* sinks_void,
const ulong sinks_offset
) { ) {
const int tid = get_local_id(0); const int tid = get_local_id(0);
const int head_batch_idx = get_global_id(1); const int head_batch_idx = get_global_id(1);
@ -250,7 +268,12 @@ __kernel void flash_attn_f32_f16_q1(
float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1); float slope = get_alibi_slope(max_bias, head_idx, n_head_log2, m0, m1);
ACC_TYPE m_i = -INFINITY; const global ACC_TYPE* sinks_ptr = NULL;
if (sinks_void != NULL) {
sinks_ptr = (const global ACC_TYPE*)((const global char*)sinks_void + sinks_offset);
}
ACC_TYPE m_i = (sinks_ptr != NULL) ? sinks_ptr[head_idx] : -INFINITY;
for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) { for (int k_idx = tid; k_idx < n_kv; k_idx += Q1_WG_SIZE) {
const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1; const ulong k_row_offset = batch_idx * k_nb3 + head_kv_idx * k_nb2 + k_idx * k_nb1;
const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset); const global KV_DATA_TYPE4* k_ptr = (const global KV_DATA_TYPE4*)(k_base + k_row_offset);
@ -323,7 +346,11 @@ __kernel void flash_attn_f32_f16_q1(
const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1; const ulong o_row_offset = batch_idx * o_nb3 + head_idx * o_nb1;
global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset); global O_DATA_TYPE4 *o_row = (global O_DATA_TYPE4 *)(o_base + o_row_offset);
const ACC_TYPE l_final = local_l[0]; ACC_TYPE l_final = local_l[0];
if (sinks_ptr != NULL) {
l_final += exp(sinks_ptr[head_idx] - m_final);
}
if (l_final > 0.0f) { if (l_final > 0.0f) {
const ACC_TYPE l_inv = 1.0f / l_final; const ACC_TYPE l_inv = 1.0f / l_final;
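The f16, f32 and f32_f16 kernels above all fold the optional per-head attention sink into the streaming softmax the same way: the running maximum m_i is compared against the sink value, the accumulated output is rescaled by exp(m_i - m_final), and exp(m_sink - m_final) is added to the normalizer l_i, so the sink enlarges the denominator without contributing a value vector. A scalar C++ sketch of that update, with hypothetical variable names:

    #include <algorithm>
    #include <cmath>
    #include <vector>

    // Sketch: finish a streaming-softmax accumulation (m_i, l_i, o_acc) by folding in
    // a sink logit m_sink, then normalize by the final denominator.
    static void apply_sink_and_normalize(float m_i, float l_i, std::vector<float> & o_acc, float m_sink) {
        const float m_final = std::max(m_i, m_sink);
        const float scale_o = std::exp(m_i - m_final);       // rescale the partial output
        for (float & o : o_acc) {
            o *= scale_o;
        }
        const float l_final = l_i * std::exp(m_i - m_final) + std::exp(m_sink - m_final);
        if (l_final > 0.0f) {
            for (float & o : o_acc) {
                o /= l_final;                                 // 1/l_final, as in the kernels
            }
        }
    }

In the single-query (q1) variants the same effect is obtained by initializing m_i to the sink value and adding exp(sink - m_final) to l_final before the division.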

View File

@ -4398,7 +4398,10 @@ static bool ggml_backend_sycl_device_supports_op(ggml_backend_dev_t dev, const g
return ggml_is_contiguous(op->src[0]); return ggml_is_contiguous(op->src[0]);
case GGML_OP_POOL_2D: case GGML_OP_POOL_2D:
case GGML_OP_ACC: case GGML_OP_ACC:
return true;
case GGML_OP_PAD: case GGML_OP_PAD:
return (ggml_get_op_params_i32(op, 0) == 0) && (ggml_get_op_params_i32(op, 2) == 0) &&
(ggml_get_op_params_i32(op, 4) == 0) && (ggml_get_op_params_i32(op, 6) == 0);
case GGML_OP_LEAKY_RELU: case GGML_OP_LEAKY_RELU:
case GGML_OP_TIMESTEP_EMBEDDING: case GGML_OP_TIMESTEP_EMBEDDING:
case GGML_OP_RWKV_WKV6: case GGML_OP_RWKV_WKV6:

File diff suppressed because it is too large

View File

@ -334,6 +334,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < Br; ++r) { [[unroll]] for (uint32_t r = 0; r < Br; ++r) {
Of[r][d] *= Lfrcp[r]; Of[r][d] *= Lfrcp[r];
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX));
#endif
} }
} }

View File

@ -373,6 +373,9 @@ void main() {
[[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) {
[[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) {
Of[r][d] *= ACC_TYPE(Lfrcp[r]); Of[r][d] *= ACC_TYPE(Lfrcp[r]);
#if defined(ACC_TYPE_MAX)
Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX);
#endif
} }
} }

View File

@ -283,6 +283,10 @@ void main() {
O = Ldiag*O; O = Ldiag*O;
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; uint32_t o_offset = iq3*p.ne2*p.ne1*HSV;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O); coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, HSV_pad, gl_MatrixUseAccumulator>(O);

View File

@ -111,6 +111,10 @@ void main() {
} }
} }
O *= L; O *= L;
const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF);
O = clamp(O, -FLT_MAX, FLT_MAX);
data_d[iq3 * D * N + D * n + d] = O; data_d[iq3 * D * N + D * n + d] = O;
} }
} }
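Several Vulkan flash-attention and matmul shaders now clamp their normalized outputs to ±ACC_TYPE_MAX (or ±FLT_MAX in the scalar path above) before the store, so an fp16 accumulator or destination does not turn large partial results into infinities. A tiny C++ sketch of the same guard, assuming the fp16 limit 65504.0f as the value behind ACC_TYPE_MAX:

    #include <algorithm>

    // Sketch: clamp a normalized output to the finite range of the storage type.
    // 65504.0f is the largest finite IEEE fp16 value; the shaders use ACC_TYPE_MAX
    // (or FLT_MAX when accumulating in fp32).
    static float clamp_for_store(float o, float acc_type_max = 65504.0f) {
        return std::clamp(o, -acc_type_max, acc_type_max);
    }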

View File

@ -7,27 +7,36 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint i00 = gl_GlobalInvocationID.x; const uint i00 = gl_GlobalInvocationID.x;
const uint i10 = gl_GlobalInvocationID.y;
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
if (i00 >= p.ne00) { if (i00 >= p.ne00) {
return; return;
} }
const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
#if defined(DATA_A_BF16) #if defined(DATA_A_BF16)
FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00]));
#else #else
FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]);
#endif #endif
#ifndef OPTIMIZATION_ERROR_WORKAROUND #ifndef OPTIMIZATION_ERROR_WORKAROUND
data_d[d_offset + i00] = D_TYPE(v); data_d[d_offset + i00] = D_TYPE(v);
#else #else
data_d[d_offset + i00] = D_TYPE(v); data_d[d_offset + i00] = D_TYPE(v);
#endif #endif
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
} }
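get_rows looks up a row index i01 from the index tensor b at (i10, i11, i12) and copies row i01 of a into the destination at the same position; the rewrite wraps that body in while-loops over the y and z dispatch dimensions so a launch capped below ne10 and ne11*ne12 still covers every output row. A CPU-side C++ sketch of the gather itself, with strides simplified, contiguous layout assumed, and the batch broadcast folded away:

    #include <cstdint>
    #include <vector>

    // Sketch: d[i12][i11][i10][:] = a[i12][i11][rows[i12][i11][i10]][:], all tensors contiguous.
    static void get_rows(const std::vector<float> & a,       // [ne12][ne11][ne01][ne00]
                         const std::vector<int32_t> & rows,  // [ne12][ne11][ne10]
                         std::vector<float> & d,             // [ne12][ne11][ne10][ne00]
                         int64_t ne00, int64_t ne01, int64_t ne10, int64_t ne11, int64_t ne12) {
        for (int64_t i12 = 0; i12 < ne12; ++i12) {
            for (int64_t i11 = 0; i11 < ne11; ++i11) {
                for (int64_t i10 = 0; i10 < ne10; ++i10) {
                    const int64_t i01 = rows[(i12*ne11 + i11)*ne10 + i10];
                    const float * src = &a[((i12*ne11 + i11)*ne01 + i01)*ne00];
                    float       * dst = &d[((i12*ne11 + i11)*ne10 + i10)*ne00];
                    for (int64_t i00 = 0; i00 < ne00; ++i00) {
                        dst[i00] = src[i00];
                    }
                }
            }
        }
    }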

View File

@ -10,9 +10,6 @@ layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() { void main() {
const uint i00 = (gl_GlobalInvocationID.x)*2; const uint i00 = (gl_GlobalInvocationID.x)*2;
const uint i10 = gl_GlobalInvocationID.y;
const uint i11 = (gl_GlobalInvocationID.z)/p.ne12;
const uint i12 = (gl_GlobalInvocationID.z)%p.ne12;
#ifdef NEEDS_INIT_IQ_SHMEM #ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize); init_iq_shmem(gl_WorkGroupSize);
@ -22,20 +19,33 @@ void main() {
return; return;
} }
const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; uint gid_z = gl_GlobalInvocationID.z;
while (gid_z < p.ne11 * p.ne12) {
uint gid_y = gl_GlobalInvocationID.y;
while (gid_y < p.ne10) {
const uint i10 = gid_y;
const uint i11 = gid_z / p.ne12;
const uint i12 = gid_z % p.ne12;
const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12];
const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint ib = a_offset + i00/QUANT_K; // block index const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03;
const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23;
const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
vec2 v = dequantize(ib, iqs, 0); const uint ib = a_offset + i00/QUANT_K; // block index
const vec2 dm = get_dm(ib, 0); const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index
v = v * dm.x + dm.y; const uint iybs = i00 - i00%QUANT_K; // dst block start index
const uint y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); vec2 v = dequantize(ib, iqs, 0);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); const vec2 dm = get_dm(ib, 0);
v = v * dm.x + dm.y;
data_d[d_offset + iybs + iqs ] = D_TYPE(v.x);
data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y);
gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y;
}
gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z;
}
} }

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}

View File

@ -0,0 +1,22 @@
#version 450
#include "generic_head.comp"
#include "types.comp"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}
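The two new shaders implement elementwise hardsigmoid and hardswish: hardsigmoid(x) = clamp((x + 3)/6, 0, 1) and hardswish(x) = x * hardsigmoid(x). The equivalent scalar C++:

    #include <algorithm>

    static float hardsigmoid(float x) {
        return std::min(1.0f, std::max(0.0f, (x + 3.0f) / 6.0f));
    }

    static float hardswish(float x) {
        return x * hardsigmoid(x);   // e.g. hardswish(3) = 3, hardswish(-3) = 0, hardswish(0) = 0
    }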

View File

@ -1,7 +1,8 @@
#extension GL_EXT_control_flow_attributes : enable #extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require #extension GL_EXT_shader_8bit_storage : require
#if USE_SUBGROUP_ADD
#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require #extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require #extension GL_KHR_shader_subgroup_arithmetic : require
#endif #endif
@ -12,10 +13,19 @@
#include "types.comp" #include "types.comp"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
@ -92,6 +102,23 @@ layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1; layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1; layout (constant_id = 2) const uint NUM_COLS = 1;
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
}
#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
@ -152,3 +179,4 @@ void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offs
} }
#endif #endif
} }
#endif
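mul_mat_vec_base gains a USE_SUBGROUP_ADD_NO_SHMEM path in which each partial result is collapsed with subgroupAdd and written out by the first invocation, replacing the shared-memory tree reduction kept in the #else branch. A CPU-side C++ sketch of the tree reduction being replaced, over BLOCK_SIZE partial sums:

    #include <vector>

    // Sketch: the shmem-based reduction halves the active range each step until
    // element 0 holds the total; subgroupAdd collapses this into one hardware primitive.
    static float reduce_tree(std::vector<float> tmpsh) {      // tmpsh.size() == BLOCK_SIZE (power of two)
        for (size_t s = tmpsh.size() / 2; s > 0; s >>= 1) {
            for (size_t tid = 0; tid < s; ++tid) {            // on the GPU: one thread per tid, then barrier()
                tmpsh[tid] += tmpsh[tid + s];
            }
        }
        return tmpsh[0];
    }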

View File

@ -0,0 +1,140 @@
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_integer_dot_product : require
#define MMQ
#define B_TYPE block_q8_1_x4
#include "mul_mat_vec_base.comp"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#define K_PER_ITER 8
#include "mul_mmq_funcs.comp"
uint a_offset, b_offset, d_offset;
int32_t cache_b_qs[2];
vec2 cache_b_ds;
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
#if QUANT_R == 2
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
ibi += p.ncols;
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = FLOAT_TYPE(0.0f);
}
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER);
i++;
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
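The new mul_mat_vecq shader accumulates integer dot products of packed int8 quants with dotPacked4x8EXT and only applies the float scales once per 8-value chunk. A C++ sketch of the packed dot product it relies on (four signed 8-bit lanes per 32-bit word), written as a CPU equivalent:

    #include <cstdint>

    // Sketch: dot product of two int32 words, each holding four signed 8-bit lanes,
    // i.e. the CPU counterpart of dotPacked4x8EXT / dp4a as used above.
    static int32_t dot_packed_4x8(int32_t a, int32_t b) {
        int32_t sum = 0;
        for (int lane = 0; lane < 4; ++lane) {
            const int8_t ai = (int8_t)((a >> (8*lane)) & 0xff);   // sign-extend each byte
            const int8_t bi = (int8_t)((b >> (8*lane)) & 0xff);
            sum += (int32_t)ai * (int32_t)bi;
        }
        return sum;
    }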

View File

@ -891,6 +891,20 @@ void main() {
barrier(); barrier();
} }
#if defined(ACC_TYPE_MAX)
#ifdef COOPMAT
[[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) {
[[unroll]] for (uint i = 0; i < sums[j].length(); ++i) {
sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
}
#else
[[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) {
sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX);
}
#endif
#endif
const uint dr = ir * BM + warp_r * WM; const uint dr = ir * BM + warp_r * WM;
const uint dc = ic * BN + warp_c * WN; const uint dc = ic * BN + warp_c * WN;

View File

@ -349,6 +349,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK; block_k += BK;
} }
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum); coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose);
@ -388,6 +392,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK; block_k += BK;
} }
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum); coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose);
@ -428,6 +436,10 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
block_k += BK; block_k += BK;
} }
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum); coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(sum);
coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose);
@ -444,18 +456,111 @@ void main() {
tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1);
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
uint k_iters = (end_k - start_k + BK - 1) / BK; uint k_iters = (end_k - start_k + BK - 1) / BK;
fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false);
store_scales(tid);
#ifdef MUL_MAT_ID
if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(0.0);
[[dont_unroll]]
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
if ((block_k % QUANT_K) == 0) {
store_scales(tid);
}
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover4, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover4, gl_MatrixUseAccumulator>(sum);
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
return;
}
if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) {
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(0.0);
[[dont_unroll]]
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
if ((block_k % QUANT_K) == 0) {
store_scales(tid);
}
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
}
if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
} else {
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BM, BK, gl_MatrixUseA> mat_a;
coopmat<MAT_TYPE, gl_ScopeWorkgroup, BK, BNover2, gl_MatrixUseB> mat_b;
coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA);
coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB);
sum = coopMatMulAdd(mat_a, mat_b, sum);
}
}
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
// Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator> mat_d;
mat_d = coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BNover2, gl_MatrixUseAccumulator>(sum);
// Call callback to store each element, remapping row through shared memory
coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic);
return;
}
#endif
coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> sum;
sum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator>(0.0);
[[dont_unroll]] [[dont_unroll]]
for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) {
store_scales(tid); if ((block_k % QUANT_K) == 0) {
if (block_k + BK < end_k) { store_scales(tid);
}
if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) {
fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false);
} }
@ -485,6 +590,9 @@ void main() {
sum = coopMatMulAdd(mat_a, mat_b, sum); sum = coopMatMulAdd(mat_a, mat_b, sum);
} }
} }
#if defined(ACC_TYPE_MAX)
[[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); }
#endif
// Convert from ACC_TYPE to D_TYPE // Convert from ACC_TYPE to D_TYPE
coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d; coopmat<D_TYPE, gl_ScopeWorkgroup, BM, BN, gl_MatrixUseAccumulator> mat_d;

View File

@ -28,7 +28,7 @@ layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED32) #if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif #endif
layout (binding = 1) readonly buffer B {block_q8_1_packed32 data_b[];}; layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
@ -98,7 +98,7 @@ shared FLOAT_TYPE_VEC2 buf_b_ds[BN];
#endif #endif
#define LOAD_VEC_A (4 * QUANT_R) #define LOAD_VEC_A (4 * QUANT_R)
#define LOAD_VEC_B 4 #define LOAD_VEC_B 16
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096]; shared u16vec2 row_ids[4096];
@ -270,15 +270,22 @@ void main() {
const uint iqs = idx & 0x7; const uint iqs = idx & 0x7;
#else #else
const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK;
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
const uint iqs = loadr_b; const uint iqs = loadr_b;
#endif #endif
const uint buf_ib = loadc_b + l; const uint buf_ib = loadc_b + l;
if (iqs == 0) { if (iqs == 0) {
buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib].ds); buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
} }
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs] = data_b[ib].qs[iqs]; const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z;
buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w;
} }
barrier(); barrier();
@ -349,7 +356,7 @@ void main() {
cache_b_qs[cc * (BK / 4) + idx_k]); cache_b_qs[cc * (BK / 4) + idx_k]);
} }
sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc]); sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1);
} }
} }
} }
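Both mul_mmq and the new mul_mat_vecq now read B as block_q8_1_x4, i.e. four q8_1 blocks packed together so their (d, s) pairs and quants can be fetched with wider loads; an index ib into plain q8_1 blocks becomes the pair (ib/4, ib%4). A C++ sketch of such a packed layout, with field names taken from the shader accesses above and the exact ggml struct left as an assumption:

    #include <cstdint>

    // Sketch of a 4-way packed q8_1 layout as used via data_b[ib_outer].ds[ib_inner]
    // and data_b[ib_outer].qs[ib_inner*8 + k] above.
    struct block_q8_1_x4_sketch {
        float   ds[4][2];   // (d, s) per sub-block; the real struct packs these as fp16 pairs
        int32_t qs[4*8];    // 4 sub-blocks x 8 int32 words, each word holding 4 int8 quants
    };

    static void locate(int ib, int & ib_outer, int & ib_inner) {
        ib_outer = ib / 4;  // which x4 super-block
        ib_inner = ib % 4;  // which q8_1 block inside it
    }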

View File

@ -16,8 +16,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F); (vui >> 4) & 0x0F0F0F0F);
} }
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 8.0f * dsb.y)); return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
} }
#endif #endif
@ -29,8 +29,8 @@ i32vec2 repack(uint ib, uint iqs) {
(vui >> 4) & 0x0F0F0F0F); (vui >> 4) & 0x0F0F0F0F);
} }
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
} }
#endif #endif
@ -50,8 +50,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1); return i32vec2(v0, v1);
} }
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - 16.0f * dsb.y)); return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
} }
#endif #endif
@ -69,8 +69,8 @@ i32vec2 repack(uint ib, uint iqs) {
return i32vec2(v0, v1); return i32vec2(v0, v1);
} }
ACC_TYPE mul_q8_1(int32_t q_sum, vec2 dma, vec2 dsb) { ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y); return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
} }
#endif #endif
@ -81,7 +81,7 @@ int32_t repack(uint ib, uint iqs) {
data_a[ib].qs[iqs * 2 + 1])); data_a[ib].qs[iqs * 2 + 1]));
} }
ACC_TYPE mul_q8_1(int32_t q_sum, float da, vec2 dsb) { ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x); return ACC_TYPE(float(q_sum) * da * dsb.x);
} }
#endif #endif
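
As a plain-C reference for the q4_0 flavour of mul_q8_1 above (the mmq path passes sum_divisor == 1): q4_0 dequantizes to d_a*(q - 8), so the dot product against a q8_1 block whose ds stores (d_b, d_b*sum(q_b)) folds the -8 offset into that precomputed sum; sum_divisor only rescales the correction term for callers not shown in this hunk. A sketch under those assumptions, shrunk to 4 values per block for readability:

#include <stdint.h>
#include <stdio.h>

/* dot(d_a*(qa-8), d_b*qb) = d_a*(sum(qa*qb)*d_b - 8*(d_b*sum(qb))),
 * with dsb = {d_b, d_b*sum(qb)} taken from the q8_1 block. */
static float mul_q8_1_q4_0(int32_t q_sum, float da, const float dsb[2],
                           int32_t sum_divisor) {
    return da * ((float) q_sum * dsb[0] - (8.0f / (float) sum_divisor) * dsb[1]);
}

int main(void) {
    const int8_t qa[4] = { 0, 4, 8, 15 };  /* q4_0 quants (offset by 8) */
    const int8_t qb[4] = { 3, -2, 7, 1 };  /* q8_1 quants               */
    const float  da = 0.5f, db = 0.25f;

    int32_t q_sum = 0, b_sum = 0;
    float   exact = 0.0f;
    for (int i = 0; i < 4; ++i) {
        q_sum += qa[i] * qb[i];
        b_sum += qb[i];
        exact += da * (qa[i] - 8) * db * qb[i];
    }
    const float dsb[2] = { db, db * (float) b_sum };
    printf("fused=%f exact=%f\n", mul_q8_1_q4_0(q_sum, da, dsb, 1), exact);
    return 0;
}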


@@ -3,6 +3,15 @@
#extension GL_EXT_control_flow_attributes : require #extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_16bit_storage : require #extension GL_EXT_shader_16bit_storage : require
#ifdef USE_SUBGROUPS
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_clustered : require
#define INVOCATION_ID gl_SubgroupInvocationID.x
#else
#define INVOCATION_ID gl_LocalInvocationID.x
#endif
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
uint ne; uint ne;
@@ -14,13 +23,19 @@ layout(constant_id = 0) const uint GROUP_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {vec4 data_a[];}; layout (binding = 0) readonly buffer A {vec4 data_a[];};
#ifndef QBLOCK_X4
layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];};
#else
layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];};
#endif
#ifndef USE_SUBGROUPS
shared float shmem[GROUP_SIZE]; shared float shmem[GROUP_SIZE];
#endif
void quantize() { void quantize() {
const uint wgid = gl_WorkGroupID.x; const uint wgid = gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x; const uint tid = INVOCATION_ID;
// Each thread handles a vec4, so 8 threads handle a block // Each thread handles a vec4, so 8 threads handle a block
const uint blocks_per_group = GROUP_SIZE / 8; const uint blocks_per_group = GROUP_SIZE / 8;
@@ -30,9 +45,19 @@ void quantize() {
const uint ib = wgid * blocks_per_group + block_in_wg; const uint ib = wgid * blocks_per_group + block_in_wg;
const uint iqs = tid % 8; const uint iqs = tid % 8;
#ifndef QBLOCK_X4
if (ib >= gl_NumWorkGroups.x * blocks_per_group) { if (ib >= gl_NumWorkGroups.x * blocks_per_group) {
return; return;
} }
#else
const uint ibx4_outer = ib / 4;
const uint ibx4_inner = ib % 4;
const uint required_x4_blocks = (p.ne + 127) / 128;
if (ibx4_outer >= required_x4_blocks) {
return;
}
#endif
const uint a_idx = ib * 8 + iqs; const uint a_idx = ib * 8 + iqs;
@@ -40,7 +65,9 @@ void quantize() {
const vec4 abs_vals = abs(vals); const vec4 abs_vals = abs(vals);
// Find absolute max for each block // Find absolute max for each block
shmem[tid] = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w));
#ifndef USE_SUBGROUPS
shmem[tid] = thread_max;
barrier(); barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) { [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) { if (iqs < s) {
@@ -50,14 +77,28 @@ void quantize() {
} }
const float amax = shmem[block_in_wg * 8]; const float amax = shmem[block_in_wg * 8];
#else
const float amax = subgroupClusteredMax(thread_max, 8);
#endif
const float d = amax / 127.0; const float d = amax / 127.0;
const float d_inv = d != 0.0 ? 1.0 / d : 0.0; const float d_inv = d != 0.0 ? 1.0 / d : 0.0;
vals = round(vals * d_inv); vals = round(vals * d_inv);
#ifndef QBLOCK_X4
data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); data_b[ib].qs[iqs] = pack32(i8vec4(round(vals)));
#else
data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals)));
#endif
#ifndef USE_SUBGROUPS
barrier(); barrier();
#endif
// Calculate the sum for each block // Calculate the sum for each block
shmem[tid] = vals.x + vals.y + vals.z + vals.w; const float thread_sum = vals.x + vals.y + vals.z + vals.w;
#ifndef USE_SUBGROUPS
shmem[tid] = thread_sum;
barrier(); barrier();
[[unroll]] for (uint s = 4; s > 0; s >>= 1) { [[unroll]] for (uint s = 4; s > 0; s >>= 1) {
if (iqs < s) { if (iqs < s) {
@@ -65,10 +106,19 @@ void quantize() {
} }
barrier(); barrier();
} }
#else
const float sum = subgroupClusteredAdd(thread_sum, 8);
#endif
if (iqs == 0) { if (iqs == 0) {
#ifndef USE_SUBGROUPS
const float sum = shmem[tid]; const float sum = shmem[tid];
#endif
#ifndef QBLOCK_X4
data_b[ib].ds = f16vec2(vec2(d, sum * d)); data_b[ib].ds = f16vec2(vec2(d, sum * d));
#else
data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d));
#endif
} }
} }
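
For orientation, the block format produced above is plain q8_1: per 32 floats, d = amax/127, quants = round(x/d), and ds packs (d, d*sum(quants)), stored as f16vec2 on the device. A host-side reference sketch, independent of the subgroup and QBLOCK_X4 variants:

#include <math.h>
#include <stdint.h>
#include <stdio.h>

#define QK8_1 32

/* Reference quantization of one q8_1 block, matching the math above. */
static void quantize_q8_1_ref(const float *x, int8_t qs[QK8_1], float ds[2]) {
    float amax = 0.0f;
    for (int i = 0; i < QK8_1; ++i) {
        const float a = fabsf(x[i]);
        if (a > amax) amax = a;
    }
    const float d     = amax / 127.0f;
    const float d_inv = d != 0.0f ? 1.0f / d : 0.0f;

    int sum = 0;
    for (int i = 0; i < QK8_1; ++i) {
        const int q = (int) roundf(x[i] * d_inv);
        qs[i] = (int8_t) q;
        sum  += q;
    }
    ds[0] = d;               /* scale           */
    ds[1] = d * (float) sum; /* d * sum(quants) */
}

int main(void) {
    float x[QK8_1];
    for (int i = 0; i < QK8_1; ++i) x[i] = 0.01f * (float) (i - 16);
    int8_t qs[QK8_1];
    float  ds[2];
    quantize_q8_1_ref(x, qs, ds);
    printf("d=%f d*sum=%f q[0]=%d\n", ds[0], ds[1], qs[0]);
    return 0;
}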


@@ -207,6 +207,18 @@ struct block_q8_1_packed32
int32_t qs[8]; int32_t qs[8];
}; };
// 4 blocks in one to allow 16-byte/128-bit alignment and loads
struct block_q8_1_x4
{
f16vec2 ds[4];
int32_t qs[32];
};
struct block_q8_1_x4_packed128
{
f16vec2 ds[4];
ivec4 qs[8];
};
// K-quants // K-quants
#define QUANT_K_Q2_K 256 #define QUANT_K_Q2_K 256
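
A quick size check on the x4 containers added here, assuming tight packing as in the GLSL declarations: 4 f16vec2 scales take 16 bytes and 32 int32 quants take 128 bytes, so each struct is 144 bytes and the qs array of the packed128 view starts on a 16-byte boundary. A hedged C mirror (f16vec2 approximated as two uint16_t):

#include <assert.h>
#include <stddef.h>
#include <stdint.h>

/* Host-side mirrors of the GLSL structs, assuming no implicit padding. */
typedef struct { uint16_t ds[4][2]; int32_t qs[32];   } block_q8_1_x4_ref;
typedef struct { uint16_t ds[4][2]; int32_t qs[8][4]; } block_q8_1_x4_packed128_ref;

int main(void) {
    static_assert(sizeof(block_q8_1_x4_ref)           == 144, "unexpected padding");
    static_assert(sizeof(block_q8_1_x4_packed128_ref) == 144, "unexpected padding");
    /* qs begins 16 bytes in, so each 16-byte qs row stays 128-bit aligned. */
    static_assert(offsetof(block_q8_1_x4_packed128_ref, qs) == 16, "layout");
    return 0;
}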


@@ -206,6 +206,22 @@ bool string_ends_with(const std::string& str, const std::string& suffix) {
return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin());
} }
bool is_quantized_type(const std::string& type_name) {
return type_name != "f32" && type_name != "f16" && type_name != "bf16";
}
bool is_legacy_quant(const std::string& type_name) {
return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0";
}
bool is_k_quant(const std::string& type_name) {
return string_ends_with(type_name, "_k");
}
bool is_iq_quant(const std::string& type_name) {
return string_starts_with(type_name, "iq");
}
static const char path_separator = '/'; static const char path_separator = '/';
std::string join_paths(const std::string& path1, const std::string& path2) { std::string join_paths(const std::string& path1, const std::string& path2) {
@@ -323,6 +339,9 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
} }
base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
if (f16acc) {
base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}
if (coopmat) { if (coopmat) {
base_dict["COOPMAT"] = "1"; base_dict["COOPMAT"] = "1";
@@ -399,7 +418,7 @@ void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool c
} }
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && (tname == "q4_0" || tname == "q4_1" || tname == "q5_0" || tname == "q5_1" || tname == "q8_0")) { if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) {
string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(base_dict, {{"FLOAT_TYPE", FLOAT_TYPE(tname)}, {data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc);
} }
#endif #endif
@@ -437,8 +456,12 @@ void process_shaders() {
// flash attention // flash attention
for (const auto& f16acc : {false, true}) { for (const auto& f16acc : {false, true}) {
std::string acctype = f16acc ? "float16_t" : "float"; std::map<std::string, std::string> fa_base_dict = base_dict;
std::string acctypev4 = f16acc ? "f16vec4" : "vec4"; fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float";
fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4";
if (f16acc) {
fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\"";
}
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
if (tname == "f32") { if (tname == "f32") {
@@ -449,30 +472,30 @@ void process_shaders() {
#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) #if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, true, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc);
} else { } else {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc);
} }
#endif #endif
#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"ACC_TYPEV4", acctypev4}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc);
} }
#endif #endif
if (tname == "f16") { if (tname == "f16") {
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}}), true, false, false, f16acc); merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc);
} else if (tname == "q4_0" || tname == "q8_0") { } else if (tname == "q4_0" || tname == "q8_0") {
std::string data_a_key = "DATA_A_" + to_uppercase(tname); std::string data_a_key = "DATA_A_" + to_uppercase(tname);
string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp",
merge_maps(base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"ACC_TYPE", acctype}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc);
} }
} }
} }
@@ -488,8 +511,20 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
// Dequant shaders // Dequant shaders
if (tname != "f16" && tname != "bf16") { if (tname != "f16" && tname != "bf16") {
string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}}));
@@ -572,7 +607,12 @@ void process_shaders() {
string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {});
string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {});
string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {});
string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}});
string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}});
string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}});
string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
@@ -617,6 +657,10 @@ void process_shaders() {
string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
for (auto rte : {false, true}) { for (auto rte : {false, true}) {
std::string suffix = rte ? "_rte" : ""; std::string suffix = rte ? "_rte" : "";
@@ -814,12 +858,21 @@ void write_output_files() {
fputs(len.c_str(), src); fputs(len.c_str(), src);
} }
for (const std::string& btype : {"f16", "f32"}) { std::vector<std::string> btypes = {"f16", "f32"};
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
btypes.push_back("q8_1");
#endif
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[2];\n", tname.c_str(), btype.c_str()); if (btype == "q8_1" && !is_legacy_quant(tname)) {
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[2];\n", tname.c_str(), btype.c_str()); continue;
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data};\n"; }
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[2] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len};\n"; fprintf(hdr, "extern unsigned char *arr_dmmv_%s_%s_f32_data[3];\n", tname.c_str(), btype.c_str());
fprintf(hdr, "extern uint64_t arr_dmmv_%s_%s_f32_len[3];\n", tname.c_str(), btype.c_str());
std::string data = "unsigned char *arr_dmmv_" + tname + "_" + btype + "_f32_data[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_data, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_data};\n";
std::string len = "uint64_t arr_dmmv_" + tname + "_" + btype + "_f32_len[3] = {mul_mat_vec_" + tname + "_" + btype + "_f32_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_len, mul_mat_vec_" + tname + "_" + btype + "_f32_subgroup_no_shmem_len};\n";
fputs(data.c_str(), src); fputs(data.c_str(), src);
fputs(len.c_str(), src); fputs(len.c_str(), src);
} }
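
The q8_1 b-type variants above are only emitted for types that pass is_legacy_quant, while the other new predicates split the remaining names by suffix/prefix. A standalone C sketch of that classification (same string rules, reimplemented outside the generator):

#include <stdbool.h>
#include <stdio.h>
#include <string.h>

static bool is_quantized(const char *t) {
    return strcmp(t, "f32") != 0 && strcmp(t, "f16") != 0 && strcmp(t, "bf16") != 0;
}
static bool is_legacy_quant(const char *t) {
    return !strcmp(t, "q4_0") || !strcmp(t, "q4_1") || !strcmp(t, "q5_0") ||
           !strcmp(t, "q5_1") || !strcmp(t, "q8_0");
}
static bool is_k_quant(const char *t) {
    const size_t n = strlen(t);
    return n >= 2 && strcmp(t + n - 2, "_k") == 0;
}
static bool is_iq_quant(const char *t) {
    return strncmp(t, "iq", 2) == 0;
}

int main(void) {
    const char *names[] = { "f16", "q8_0", "q4_k", "iq2_xs" };
    for (int i = 0; i < 4; ++i) {
        printf("%-7s quantized=%d legacy=%d k=%d iq=%d\n", names[i],
               is_quantized(names[i]), is_legacy_quant(names[i]),
               is_k_quant(names[i]), is_iq_quant(names[i]));
    }
    return 0;
}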


@@ -611,6 +611,8 @@ static bool ggml_webgpu_encode_node(webgpu_context ctx, ggml_tensor * node) {
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RESHAPE:
return false; return false;
case GGML_OP_CPY: case GGML_OP_CPY:
{ {
@@ -1062,6 +1064,8 @@ static bool ggml_backend_webgpu_device_supports_op(ggml_backend_dev_t dev, const
case GGML_OP_NONE: case GGML_OP_NONE:
case GGML_OP_VIEW: case GGML_OP_VIEW:
case GGML_OP_PERMUTE: case GGML_OP_PERMUTE:
case GGML_OP_TRANSPOSE:
case GGML_OP_RESHAPE:
return true; return true;
case GGML_OP_CPY: case GGML_OP_CPY:
case GGML_OP_SET_ROWS: case GGML_OP_SET_ROWS:
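
Context for treating GGML_OP_TRANSPOSE and GGML_OP_RESHAPE like VIEW/PERMUTE here: in ggml these ops only create views that re-describe an existing buffer, so the backend can report them as supported while emitting no GPU work. A minimal sketch with the public ggml API (CPU context, just to show the aliasing; not WebGPU-specific):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 16 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context *ctx = ggml_init(ip);

    struct ggml_tensor *a = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, 4, 6);
    struct ggml_tensor *t = ggml_transpose(ctx, a);         /* swaps ne/nb only  */
    struct ggml_tensor *r = ggml_reshape_2d(ctx, a, 6, 4);  /* same data, new ne */

    /* Both results alias a's buffer, which is why a backend can skip them. */
    printf("transpose aliases a: %d, reshape aliases a: %d\n",
           t->data == a->data, r->data == a->data);

    ggml_free(ctx);
    return 0;
}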


@@ -974,6 +974,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"CONV_TRANSPOSE_1D", "CONV_TRANSPOSE_1D",
"IM2COL", "IM2COL",
"IM2COL_BACK", "IM2COL_BACK",
"IM2COL_3D",
"CONV_2D", "CONV_2D",
"CONV_3D", "CONV_3D",
"CONV_2D_DW", "CONV_2D_DW",
@@ -1018,7 +1019,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
"GLU", "GLU",
}; };
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = { static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"none", "none",
@@ -1077,6 +1078,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"conv_transpose_1d(x)", "conv_transpose_1d(x)",
"im2col(x)", "im2col(x)",
"im2col_back(x)", "im2col_back(x)",
"im2col_3d(x)",
"conv_2d(x)", "conv_2d(x)",
"conv_3d(x)", "conv_3d(x)",
"conv_2d_dw(x)", "conv_2d_dw(x)",
@@ -1121,7 +1123,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
"glu(x)", "glu(x)",
}; };
static_assert(GGML_OP_COUNT == 89, "GGML_OP_COUNT != 89"); static_assert(GGML_OP_COUNT == 90, "GGML_OP_COUNT != 90");
static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2"); static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
@@ -4361,6 +4363,91 @@ struct ggml_tensor * ggml_conv_2d(
return result; return result;
} }
// a: [OC*IC, KD, KH, KW]
// b: [N*IC, ID, IH, IW]
// result: [N*OD, OH, OW, IC * KD * KH * KW]
struct ggml_tensor * ggml_im2col_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2, // dilation depth
enum ggml_type dst_type) {
const int64_t N = b->ne[3] / IC;
const int64_t ID = b->ne[2];
const int64_t IH = b->ne[1];
const int64_t IW = b->ne[0];
const int64_t OC = a->ne[3] / IC;
UNUSED(OC);
const int64_t KD = a->ne[2];
const int64_t KH = a->ne[1];
const int64_t KW = a->ne[0];
const int64_t OD = ggml_calc_conv_output_size(ID, KD, s2, p2, d2);
const int64_t OH = ggml_calc_conv_output_size(IH, KH, s1, p1, d1);
const int64_t OW = ggml_calc_conv_output_size(IW, KW, s0, p0, d0);
GGML_ASSERT((OD > 0) && "b too small compared to a");
GGML_ASSERT((OH > 0) && "b too small compared to a");
GGML_ASSERT((OW > 0) && "b too small compared to a");
const int64_t ne[4] = {KW*KH*KD*IC, OW, OH, OD*N};
struct ggml_tensor * result = ggml_new_tensor(ctx, dst_type, 4, ne);
int32_t params[] = { s0, s1, s2, p0, p1, p2, d0, d1, d2, (int32_t)IC};
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_IM2COL_3D;
result->src[0] = a;
result->src[1] = b;
return result;
}
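
A worked shape for the new operator, using the usual conv output formula O = (I + 2p - d*(k - 1) - 1)/s + 1 behind ggml_calc_conv_output_size: with IC = 3, N = 1, an 8x8x8 input and a 3x3x3 kernel at stride 1, no padding and no dilation, each spatial dim gives (8 - 2 - 1)/1 + 1 = 6, so the result is [KW*KH*KD*IC, OW, OH, OD*N] = [81, 6, 6, 6]. A hedged sketch that only builds this node and prints its shape, assuming the usual ggml context setup:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 64 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context *ctx = ggml_init(ip);

    const int64_t IC = 3, OC = 2, N = 1;
    /* ggml ne order: a = [KW, KH, KD, OC*IC], b = [IW, IH, ID, N*IC] */
    struct ggml_tensor *a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 3, OC * IC);
    struct ggml_tensor *b = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 8, 8, N * IC);

    struct ggml_tensor *cols = ggml_im2col_3d(ctx, a, b, IC,
                                              1, 1, 1,   /* strides  */
                                              0, 0, 0,   /* padding  */
                                              1, 1, 1,   /* dilation */
                                              GGML_TYPE_F32);
    /* Expect 81 6 6 6. */
    printf("%lld %lld %lld %lld\n",
           (long long) cols->ne[0], (long long) cols->ne[1],
           (long long) cols->ne[2], (long long) cols->ne[3]);
    ggml_free(ctx);
    return 0;
}
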
// a: [OC*IC, KD, KH, KW]
// b: [N*IC, ID, IH, IW]
// result: [N*OC, OD, OH, OW]
struct ggml_tensor * ggml_conv_3d(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int64_t IC,
int s0, // stride width
int s1, // stride height
int s2, // stride depth
int p0, // padding width
int p1, // padding height
int p2, // padding depth
int d0, // dilation width
int d1, // dilation height
int d2 // dilation depth
) {
struct ggml_tensor * im2col = ggml_im2col_3d(ctx, a, b, IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, a->type); // [N*OD, OH, OW, IC * KD * KH * KW]
int64_t OC = a->ne[3] / IC;
int64_t N = b->ne[3] / IC;
struct ggml_tensor * result =
ggml_mul_mat(ctx,
ggml_reshape_2d(ctx, im2col, im2col->ne[0], im2col->ne[3] * im2col->ne[2] * im2col->ne[1]), // [N*OD, OH, OW, IC * KD * KH * KW] => [N*OD*OH*OW, IC * KD * KH * KW]
ggml_reshape_2d(ctx, a, (a->ne[0] * a->ne[1] * a->ne[2] * IC), OC)); // [OC*IC, KD, KH, KW] => [OC, IC * KD * KH * KW]
int64_t OD = im2col->ne[3] / N;
result = ggml_reshape_4d(ctx, result, im2col->ne[1]*im2col->ne[2], OD, N, OC); // [OC, N*OD*OH*OW] => [OC, N, OD, OH*OW]
result = ggml_cont(ctx, ggml_permute(ctx, result, 0, 1, 3, 2)); // [N, OC, OD, OH*OW]
result = ggml_reshape_4d(ctx, result, im2col->ne[1], im2col->ne[2], OD, OC * N); // [N*OC, OD, OH, OW]
return result;
}
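
The im2col-based ggml_conv_3d above is a single mul_mat in disguise: each im2col row is an IC*KD*KH*KW patch, the kernel is reshaped to match, and the product is permuted back to [OW, OH, OD, OC*N]. A hedged usage sketch continuing the shapes from the previous example (graph construction only, nothing is computed):

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 128 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context *ctx = ggml_init(ip);

    const int64_t IC = 3, OC = 2, N = 1;
    struct ggml_tensor *k = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 3, 3, 3, OC * IC);
    struct ggml_tensor *x = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 8, 8, 8, N * IC);

    struct ggml_tensor *y = ggml_conv_3d(ctx, k, x, IC,
                                         1, 1, 1,   /* strides  */
                                         0, 0, 0,   /* padding  */
                                         1, 1, 1);  /* dilation */
    /* Expect 6 6 6 2, i.e. [OW, OH, OD, OC*N]. */
    printf("%lld %lld %lld %lld\n",
           (long long) y->ne[0], (long long) y->ne[1],
           (long long) y->ne[2], (long long) y->ne[3]);
    ggml_free(ctx);
    return 0;
}
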
// ggml_conv_2d_sk_p0 // ggml_conv_2d_sk_p0
struct ggml_tensor * ggml_conv_2d_sk_p0( struct ggml_tensor * ggml_conv_2d_sk_p0(
@@ -4482,9 +4569,9 @@ struct ggml_tensor * ggml_conv_2d_direct(
return result; return result;
} }
// ggml_conv_3d // ggml_conv_3d_direct
struct ggml_tensor * ggml_conv_3d( struct ggml_tensor * ggml_conv_3d_direct(
struct ggml_context * ctx, struct ggml_context * ctx,
struct ggml_tensor * a, struct ggml_tensor * a,
struct ggml_tensor * b, struct ggml_tensor * b,
@@ -4710,11 +4797,36 @@ struct ggml_tensor * ggml_pad(
int p1, int p1,
int p2, int p2,
int p3) { int p3) {
return ggml_pad_ext(ctx, a, 0, p0, 0, p1, 0, p2, 0, p3);
}
struct ggml_tensor * ggml_pad_ext(
struct ggml_context * ctx,
struct ggml_tensor * a,
int lp0,
int rp0,
int lp1,
int rp1,
int lp2,
int rp2,
int lp3,
int rp3
) {
struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type, struct ggml_tensor * result = ggml_new_tensor_4d(ctx, a->type,
a->ne[0] + p0, a->ne[0] + lp0 + rp0,
a->ne[1] + p1, a->ne[1] + lp1 + rp1,
a->ne[2] + p2, a->ne[2] + lp2 + rp2,
a->ne[3] + p3); a->ne[3] + lp3 + rp3);
ggml_set_op_params_i32(result, 0, lp0);
ggml_set_op_params_i32(result, 1, rp0);
ggml_set_op_params_i32(result, 2, lp1);
ggml_set_op_params_i32(result, 3, rp1);
ggml_set_op_params_i32(result, 4, lp2);
ggml_set_op_params_i32(result, 5, rp2);
ggml_set_op_params_i32(result, 6, lp3);
ggml_set_op_params_i32(result, 7, rp3);
result->op = GGML_OP_PAD; result->op = GGML_OP_PAD;
result->src[0] = a; result->src[0] = a;
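
ggml_pad keeps its old meaning (trailing padding only) by forwarding zero leading pads, while ggml_pad_ext takes a leading/trailing pair per dimension. A hedged shape sketch:

#include <stdio.h>
#include "ggml.h"

int main(void) {
    struct ggml_init_params ip = {
        .mem_size   = 16 * 1024 * 1024,
        .mem_buffer = NULL,
        .no_alloc   = false,
    };
    struct ggml_context *ctx = ggml_init(ip);

    struct ggml_tensor *a = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, 5, 4, 3, 2);

    /* Old-style: pad only on the high side of each dimension. */
    struct ggml_tensor *p  = ggml_pad(ctx, a, 1, 0, 0, 0);
    /* New-style: independent leading/trailing padding per dimension. */
    struct ggml_tensor *pe = ggml_pad_ext(ctx, a, 2, 1, 0, 3, 0, 0, 0, 0);

    /* Expect p = {6, 4, 3, 2} and pe = {8, 7, 3, 2} (5+2+1 and 4+0+3). */
    printf("p  = %lld %lld %lld %lld\n", (long long) p->ne[0],  (long long) p->ne[1],
                                         (long long) p->ne[2],  (long long) p->ne[3]);
    printf("pe = %lld %lld %lld %lld\n", (long long) pe->ne[0], (long long) pe->ne[1],
                                         (long long) pe->ne[2], (long long) pe->ne[3]);
    ggml_free(ctx);
    return 0;
}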
