Merge branch 'master' into HEAD

Georgi Gerganov 2025-11-29 22:38:44 +02:00
commit d8d98bb4bb
No known key found for this signature in database
GPG Key ID: 449E073F9DC10735
98 changed files with 6675 additions and 1609 deletions

View File

@ -50,6 +50,7 @@ WORKDIR /app
RUN apt-get update \
&& apt-get install -y \
build-essential \
git \
python3 \
python3-pip \

View File

@ -45,7 +45,7 @@ sd=`dirname $0`
cd $sd/../
SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON"
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
@ -428,10 +428,10 @@ function gg_run_qwen3_0_6b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
function check_ppl {
qnt="$1"
@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
set +e
}
@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output
# rerank score 0: 0.029

View File

@ -694,6 +694,12 @@ static bool is_autoy(const std::string & value) {
}
common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **)) {
// default values specific to example
// note: we place it here instead of inside server.cpp to allow llama-gen-docs to pick it up
if (ex == LLAMA_EXAMPLE_SERVER) {
params.use_jinja = true;
}
// load dynamic backends
ggml_backend_load_all();
@ -2495,11 +2501,18 @@ common_params_context common_params_parser_init(common_params & params, llama_ex
).set_examples({LLAMA_EXAMPLE_SERVER}));
add_opt(common_arg(
{"--jinja"},
"use jinja template for chat (default: disabled)",
string_format("use jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = true;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_JINJA"));
add_opt(common_arg(
{"--no-jinja"},
string_format("disable jinja template for chat (default: %s)\n", params.use_jinja ? "enabled" : "disabled"),
[](common_params & params) {
params.use_jinja = false;
}
).set_examples({LLAMA_EXAMPLE_SERVER, LLAMA_EXAMPLE_MAIN, LLAMA_EXAMPLE_MTMD}).set_env("LLAMA_ARG_NO_JINJA"));
add_opt(common_arg(
{"--reasoning-format"}, "FORMAT",
"controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:\n"

View File

@ -13,6 +13,120 @@
using json = nlohmann::ordered_json;
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
const common_regex & prefix,
size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json{
{ "code", code + builder.healing_marker() }
})
.dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json{
{ "code", code }
})
.dump();
}
return arguments;
}
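// Worked example of the truncation above (hypothetical values, for illustration only):
//   code == "print(1" (partial generation), builder.healing_marker() == "$HEAL$"
//   dumped JSON : {"code":"print(1$HEAL$"}
//   truncated   : {"code":"print(1
// i.e. the returned arguments are a clean JSON prefix that streaming clients can keep
// extending as more of the code arrives.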
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
function_regex ? builder.try_find_regex(*function_regex, from) :
std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
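// Illustrative caller of parse_json_tool_calls() for a made-up format
// `<call name="NAME">{...json args...}</call>` (the format, regexes and function name
// below are hypothetical examples, not an existing chat template):
static void parse_example_format(common_chat_msg_parser & builder) {
    static const common_regex example_function_regex("<call name=\"([^\"]+)\">");
    static const common_regex example_close_regex("</call>");
    parse_json_tool_calls(
        builder,
        /* block_open= */ std::nullopt,
        /* function_regex_start_only= */ std::nullopt,
        example_function_regex,
        example_close_regex,
        /* block_close= */ std::nullopt);
}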
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax)
{
@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear();
}
/**
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
* to reduce incremental compile time for parser changes.
*/
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
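// For illustration, a hypothetical built-in tool invocation in the shape matched by
// builtin_call_regex / function_name_regex / arg_name_regex above:
//   <|python_tag|>brave_search.call(query="latest llama.cpp release")
// would yield a tool call named "brave_search" with arguments {"query": "latest llama.cpp release"}.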
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<tool▁sep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<tool▁sep>)");
static const common_regex close_regex("(?:[\\s]*)?<tool▁call▁end>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
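// For reference, a complete tool call assembled from the delimiters in `form` above
// looks like this (function name and parameter values are hypothetical):
//   <minimax:tool_call><invoke name="get_weather"><parameter name="city">Paris</parameter></invoke></minimax:tool_call>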
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
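// The same xml_tool_call_format machinery can also describe a JSON-style argument
// block: concatenating the delimiters above with a hypothetical call would give
//   <|tool_calls_section_begin|><|tool_call_begin|>get_weather<|tool_call_argument_begin|>{"city": "Paris"}<|tool_call_end|><|tool_calls_section_end|>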
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
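// Illustrative input in the channel format handled above (tool name and arguments are
// hypothetical; the stream is a single string, shown here on two comment lines):
//   <|channel|>analysis<|message|>need the forecast<|end|>
//   <|start|>assistant<|channel|>commentary to=functions.get_weather <|constrain|>json<|message|>{"city":"Paris"}
// The first header is consumed as reasoning via analysis_regex, the second as a tool
// call via tool_call2_regex.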
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
builder.consume_spaces();
if (!builder.try_consume_literal("<|tools_suffix|>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
for (const auto & value : tool_calls_data.json) {
if (value.is_object()) {
builder.add_tool_call_short_form(value);
}
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<seed:tool_call>",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</function>",
/* form.scope_end = */ "</seed:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}
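// Minimal usage sketch (hypothetical input and tool name; assumes the declarations in
// chat.h / chat-parser.h):
static common_chat_msg example_parse_hermes_tool_call() {
    common_chat_syntax syntax;
    syntax.format           = COMMON_CHAT_FORMAT_HERMES_2_PRO;
    syntax.parse_tool_calls = true;
    // yields a single tool call named "get_weather" with arguments {"city": "Paris"}
    return common_chat_parse(
        "<tool_call>{\"name\": \"get_weather\", \"arguments\": {\"city\": \"Paris\"}}</tool_call>",
        /* is_partial= */ false, syntax);
}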

File diff suppressed because it is too large

View File

@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
}
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
};
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};

View File

@ -4183,6 +4183,36 @@ class Qwen3MoeModel(Qwen2MoeModel):
super().set_vocab()
@ModelBase.register("Qwen3NextForCausalLM")
class Qwen3NextModel(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.QWEN3NEXT
def set_gguf_parameters(self):
super().set_gguf_parameters()
self.gguf_writer.add_ssm_conv_kernel(self.hparams["linear_conv_kernel_dim"])
self.gguf_writer.add_ssm_state_size(self.hparams["linear_key_head_dim"])
self.gguf_writer.add_ssm_group_count(self.hparams["linear_num_key_heads"])
self.gguf_writer.add_ssm_time_step_rank(self.hparams["linear_num_value_heads"])
self.gguf_writer.add_ssm_inner_size(self.hparams["linear_value_head_dim"] * self.hparams["linear_num_value_heads"])
if (rope_dim := self.hparams.get("head_dim")) is None:
rope_dim = self.hparams["hidden_size"] // self.hparams["num_attention_heads"]
self.gguf_writer.add_rope_dimension_count(int(rope_dim * self.hparams.get("partial_rotary_factor", 0.25)))
def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
if name.startswith("mtp"):
return [] # ignore MTP layers for now
if name.endswith(".A_log"):
data_torch = -torch.exp(data_torch)
elif name.endswith(".dt_bias"):
name = name.rpartition(".dt_bias")[0] + ".dt_proj.bias"
elif "conv1d" in name:
data_torch = data_torch.squeeze()
elif name.endswith("norm.weight") and not name.endswith("linear_attn.norm.weight"):
data_torch = data_torch + 1
yield from super().modify_tensors(data_torch, name, bid)
@ModelBase.register("RND1")
class RND1Model(Qwen2MoeModel):
model_arch = gguf.MODEL_ARCH.RND1

View File

@ -42,6 +42,9 @@ The following releases are verified and recommended:
## News
- 2025.11
- Support allocating more than 4 GB of memory on a device.
- 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPUs for all dGPUs and built-in GPUs since MTL. This increases LLM performance (llama-2-7b.Q4_0.gguf) by 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent|
@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Allow allocating more than 4 GB of device memory.|
## Known Issues
@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
You need to enable allocations larger than 4 GB by setting:
```
# Linux
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
:: Windows
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
```
### **GitHub contribution**:
Please add the `SYCL :` prefix/tag to issue and PR titles to help the SYCL contributors check and address them without delay.

View File

@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fall back to the unified KV cache
// in order to support any number of prompts
if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true;
params.n_parallel = n_seq_max;
}
// utilize the full context
@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
params.n_ubatch = params.n_batch;
}
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
llama_backend_init();
llama_numa_init(params.numa);

View File

@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'}
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')

View File

@ -4,6 +4,11 @@ set -e
# First try command line argument, then environment variable, then file
CONVERTED_MODEL="${1:-"$CONVERTED_MODEL"}"
MODEL_TESTING_PROMPT="${2:-"$MODEL_TESTING_PROMPT"}"
if [ -z "$MODEL_TESTING_PROMPT"]; then
MODEL_TESTING_PROMPT="Hello, my name is"
fi
# Final check if we have a model path
if [ -z "$CONVERTED_MODEL" ]; then
@ -14,7 +19,8 @@ if [ -z "$CONVERTED_MODEL" ]; then
fi
echo $CONVERTED_MODEL
echo $MODEL_TESTING_PROMPT
cmake --build ../../build --target llama-logits -j8
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "Hello, my name is"
../../build/bin/llama-logits -m "$CONVERTED_MODEL" "$MODEL_TESTING_PROMPT"

View File

@ -184,8 +184,12 @@ model_name = os.path.basename(model_path)
# of using AutoModelForCausalLM.
print(f"Model class: {model.__class__.__name__}")
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids
device = next(model.parameters()).device
if os.getenv("MODEL_TESTING_PROMPT"):
prompt = os.getenv("MODEL_TESTING_PROMPT")
else:
prompt = "Hello, my name is"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
print(f"Input tokens: {input_ids}")
print(f"Input text: {repr(prompt)}")

View File

@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99
CONTEXT=4096
# Support allocating more than 4 GB of device memory.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU"

View File

@ -6,7 +6,7 @@
# If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable
#export ONEAPI_DEVICE_SELECTOR="level_zero:0"
export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1
@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096
# Support allocating more than 4 GB of device memory.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else
#use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT}
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: Support allocating more than 4 GB of device memory.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: Support allocating more than 4 GB of device memory.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99

View File

@ -183,6 +183,7 @@ endif()
# ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
# 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)

View File

@ -8,7 +8,7 @@ extern "C" {
#endif
#define RPC_PROTO_MAJOR_VERSION 3
#define RPC_PROTO_MINOR_VERSION 0
#define RPC_PROTO_MINOR_VERSION 5
#define RPC_PROTO_PATCH_VERSION 0
#define GGML_RPC_MAX_SERVERS 16

View File

@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif()
if (GGML_SCHED_NO_REALLOC)
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml
ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml)

View File

@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
}
if (realloc) {
#ifndef NDEBUG
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
{
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
if (cur_size > 0) {
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
__func__, ggml_backend_buft_name(galloc->bufts[i]),
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
}
}
#endif
ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
if (galloc->buffers[i] == NULL) {

View File

@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate graph
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
#ifdef GGML_SCHED_NO_REALLOC
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
#endif
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
// the re-allocation may cause the split inputs to be moved to a different address
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]);
}
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);

View File

@ -33,10 +33,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -44,12 +46,14 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#define ggml_gemm_iq4_nl_8x8_q8_0_generic ggml_gemm_iq4_nl_8x8_q8_0
#elif defined(__aarch64__) || defined(__arm__) || defined(_M_ARM) || defined(_M_ARM64)
// repack.cpp
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
@ -58,11 +62,14 @@
#elif defined(__x86_64__) || defined(__i386__) || defined(_M_IX86) || defined(_M_X64)
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
#elif defined(__POWERPC__) || defined(__powerpc__)
// ref: https://github.com/ggml-org/llama.cpp/pull/14146#issuecomment-2972561679
@ -74,10 +81,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -85,6 +94,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -99,10 +109,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -110,6 +122,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -132,15 +145,18 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
#define ggml_gemv_iq4_nl_8x8_q8_0_generic ggml_gemv_iq4_nl_8x8_q8_0
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -161,10 +177,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -172,6 +190,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0
@ -194,10 +213,12 @@
// repack.cpp
#define ggml_quantize_mat_q8_0_4x4_generic ggml_quantize_mat_q8_0_4x4
#define ggml_quantize_mat_q8_0_4x8_generic ggml_quantize_mat_q8_0_4x8
#define ggml_quantize_mat_q8_K_4x4_generic ggml_quantize_mat_q8_K_4x4
#define ggml_quantize_mat_q8_K_4x8_generic ggml_quantize_mat_q8_K_4x8
#define ggml_gemv_q4_0_4x4_q8_0_generic ggml_gemv_q4_0_4x4_q8_0
#define ggml_gemv_q4_0_4x8_q8_0_generic ggml_gemv_q4_0_4x8_q8_0
#define ggml_gemv_q4_0_8x8_q8_0_generic ggml_gemv_q4_0_8x8_q8_0
#define ggml_gemv_q4_K_8x4_q8_K_generic ggml_gemv_q4_K_8x4_q8_K
#define ggml_gemv_q4_K_8x8_q8_K_generic ggml_gemv_q4_K_8x8_q8_K
#define ggml_gemv_q2_K_8x8_q8_K_generic ggml_gemv_q2_K_8x8_q8_K
#define ggml_gemv_iq4_nl_4x4_q8_0_generic ggml_gemv_iq4_nl_4x4_q8_0
@ -205,6 +226,7 @@
#define ggml_gemm_q4_0_4x4_q8_0_generic ggml_gemm_q4_0_4x4_q8_0
#define ggml_gemm_q4_0_4x8_q8_0_generic ggml_gemm_q4_0_4x8_q8_0
#define ggml_gemm_q4_0_8x8_q8_0_generic ggml_gemm_q4_0_8x8_q8_0
#define ggml_gemm_q4_K_8x4_q8_K_generic ggml_gemm_q4_K_8x4_q8_K
#define ggml_gemm_q4_K_8x8_q8_K_generic ggml_gemm_q4_K_8x8_q8_K
#define ggml_gemm_q2_K_8x8_q8_K_generic ggml_gemm_q2_K_8x8_q8_K
#define ggml_gemm_iq4_nl_4x4_q8_0_generic ggml_gemm_iq4_nl_4x4_q8_0

View File

@ -497,6 +497,140 @@ void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemv_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 8;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_groups = ncols_interleaved / 4; // 0123 and 4567
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 1x8 tile = 2 x 4
float32x4_t acc_f32[col_groups];
const block_q8_K * GGML_RESTRICT q8_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < col_groups; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
float32x4_t q4_d_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d)); // d0 d1 d2 d3
float32x4_t q4_d_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4)); // d4 d5 d6 d7
float32x4_t q8_d = vdupq_n_f32(q8_ptr[b].d);
float32x4_t sb_scale_0123 = vmulq_f32(q4_d_0, q8_d);
float32x4_t sb_scale_4567 = vmulq_f32(q4_d_1, q8_d);
float32x4_t q4_dmin_0 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin)); // dmin 0..3
float32x4_t q4_dmin_1 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4)); // dmin 4..7
float32x4_t sb_min_0123 = vmulq_f32(q4_dmin_0, q8_d);
float32x4_t sb_min_4567 = vmulq_f32(q4_dmin_1, q8_d);
// interleaved bias_acc: [0]->r0 0123, [1]->r0 4567
int32x4_t bias_acc[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
int32x4_t acc_lo[col_groups];
int32x4_t acc_hi[col_groups];
// bsums holds 16 entries (one per 16 quants); a pairwise add leaves the 8 per-subblock sums of the superblock
const int16x8_t bsums = vpaddq_s16(vld1q_s16(q8_ptr[b].bsums), vld1q_s16(q8_ptr[b].bsums + 8));
int16_t bsums_arr[8];
vst1q_s16(bsums_arr, bsums);
for (int sb = 0; sb < QK_K / 64; sb++) {
for (int i = 0; i < col_groups; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales and mins for the low-nibble and high-nibble 32-quant subblocks
// 12 bytes each across the 8 interleaved columns: 2 * 12 = 24 bytes per 64-quant chunk, 4 chunks -> 96 bytes total
int16x8_t q4sb_mins[2];
int16x8_t q4sb_scales[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
int8x16_t q8_qs[64 / 16];
for (int i = 0; i < 64 / 16; i++) {
q8_qs[i] = vld1q_s8(q8_ptr[b].qs + sb * 64 + i * 16);
}
for (int c = 0; c < col_groups; c++) {
uint8x16_t q4_cols[8];
for (int i = 0; i < 8; i++) {
q4_cols[i] = vld1q_u8(q4_ptr[b].qs + sb * QK_K + i * 32 + 16 * c);
}
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[0], m4b)), q8_qs[0], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[1], m4b)), q8_qs[0], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[2], m4b)), q8_qs[0], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[3], m4b)), q8_qs[0], 3);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[4], m4b)), q8_qs[1], 0);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[5], m4b)), q8_qs[1], 1);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[6], m4b)), q8_qs[1], 2);
acc_lo[c] = vdotq_laneq_s32(acc_lo[c], vreinterpretq_s8_u8(vandq_u8(q4_cols[7], m4b)), q8_qs[1], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[0], 4)), q8_qs[2], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[1], 4)), q8_qs[2], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[2], 4)), q8_qs[2], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[3], 4)), q8_qs[2], 3);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[4], 4)), q8_qs[3], 0);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[5], 4)), q8_qs[3], 1);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[6], 4)), q8_qs[3], 2);
acc_hi[c] = vdotq_laneq_s32(acc_hi[c], vreinterpretq_s8_u8(vshrq_n_u8(q4_cols[7], 4)), q8_qs[3], 3);
}
// Scales
// row c0123 blk0 and blk1
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const float32x4_t sumf_0123 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[0]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[0])));
acc_f32[0] = vfmaq_f32(acc_f32[0], sb_scale_0123, sumf_0123);
// row c4567 blk0 and blk1
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
const float32x4_t sumf_4567 = vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[1]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[1])));
acc_f32[1] = vfmaq_f32(acc_f32[1], sb_scale_4567, sumf_4567);
// Bias Correction
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[2 * sb + 0]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[2 * sb + 1]);
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[0] = vmlal_s16(bias_acc[0], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[1] = vmlal_s16(bias_acc[1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
} // for sb
acc_f32[0] = vmlsq_f32(acc_f32[0], vcvtq_f32_s32(bias_acc[0]), sb_min_0123);
acc_f32[1] = vmlsq_f32(acc_f32[1], vcvtq_f32_s32(bias_acc[1]), sb_min_4567);
} // for b
int base = x * ncols_interleaved;
vst1q_f32(s + base, acc_f32[0]);
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // #if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
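For orientation, the quantity the kernel above accumulates for each of the 8 interleaved output columns j can be written as follows (a sketch of the intended math, reconstructed from the generic reference path later in this diff rather than quoted from the source; b runs over superblocks, sb over the 8 subblocks of 32 quants):
\[
s_j = \sum_{b} d8^{(b)} \Big( d_j^{(b)} \sum_{sb=0}^{7} sc_{j,sb}^{(b)} \sum_{i=0}^{31} q4_{j,sb,i}^{(b)} \, q8_{sb,i}^{(b)} \;-\; dmin_j^{(b)} \sum_{sb=0}^{7} m_{j,sb}^{(b)} \, \mathrm{bsum}_{sb}^{(b)} \Big), \qquad \mathrm{bsum}_{sb}^{(b)} = \sum_{i=0}^{31} q8_{sb,i}^{(b)}
\]
The first term is what acc_f32 accumulates through the dot products and subblock scales; the second is the dmin/bsums bias correction applied through bias_acc.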
void ggml_gemv_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,
@ -518,7 +652,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON)
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int col_pairs = ncols_interleaved / 2;
const uint8x16_t m4b = vdupq_n_u8(0x0f);
@ -615,7 +749,6 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
float32x4_t sb_scale = p == 0 ? sb_scale_0 : sb_scale_1;
// 0123 or 4567
// TODO: Single superblock mul at the end of the superblock
float32x4_t sumf_0 =
vcvtq_f32_s32(vmulq_s32(vmovl_s16(group_scales_lo), vpaddq_s32(acc_lo[p], acc_lo[p + 1])));
acc_f32[i] = vfmaq_f32(acc_f32[i], sb_scale, sumf_0);
@ -649,7 +782,7 @@ void ggml_gemv_q4_K_8x8_q8_K(int n,
vst1q_f32(s + base + 4, acc_f32[1]);
} // for x
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON)
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemv_q4_K_8x8_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
@ -2069,6 +2202,206 @@ void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
ggml_gemm_iq4_nl_4x4_q8_0_generic(n, s, bs, vx, vy, nr, nc);
}
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
constexpr int qk = QK_K;
const int nb = n / qk;
constexpr int ncols_interleaved = 8;
constexpr int blocklen = 4;
assert(n % qk == 0);
assert(nr % 4 == 0);
assert(nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
#if defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
constexpr int q8_k_blocklen = 4;
constexpr int acc_size = 2 * 4; // 4 rows × 2 column groups (c0123 / c4567)
const uint8x16_t m4b = vdupq_n_u8(0x0f);
// 8 accumulators, stored interleaved as acc_f32[2*row + col_group] to match the output layout
float32x4_t acc_f32[acc_size];
for (int y = 0; y < nr / q8_k_blocklen; y++) {
const block_q8_Kx4 * GGML_RESTRICT q8_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * GGML_RESTRICT q4_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int i = 0; i < acc_size; i++) {
acc_f32[i] = vdupq_n_f32(0);
}
for (int b = 0; b < nb; b++) {
// d4 0 1 2 3, 4 5 6 7
float32x4_t q4_d_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d));
float32x4_t q4_d_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].d + 4));
// d8 0 1 2 3
float32x4_t q8_d_0123 = vld1q_f32(q8_ptr[b].d);
// mins
float32x4_t q4_dmin_0123 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin));
float32x4_t q4_dmin_4567 = vcvt_f32_f16(vld1_f16((const __fp16 *) q4_ptr[b].dmin + 4));
// Precomputation of scales and mins
float32x4_t sbd_scale_0123[q8_k_blocklen];
float32x4_t sbd_scale_4567[q8_k_blocklen];
float32x4_t sbd_min_0123[q8_k_blocklen];
float32x4_t sbd_min_4567[q8_k_blocklen];
sbd_scale_0123[0] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 0);
sbd_scale_4567[0] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 0);
sbd_min_0123[0] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 0);
sbd_min_4567[0] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 0);
sbd_scale_0123[1] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 1);
sbd_scale_4567[1] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 1);
sbd_min_0123[1] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 1);
sbd_min_4567[1] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 1);
sbd_scale_0123[2] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 2);
sbd_scale_4567[2] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 2);
sbd_min_0123[2] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 2);
sbd_min_4567[2] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 2);
sbd_scale_0123[3] = vmulq_laneq_f32(q4_d_0123, q8_d_0123, 3);
sbd_scale_4567[3] = vmulq_laneq_f32(q4_d_4567, q8_d_0123, 3);
sbd_min_0123[3] = vmulq_laneq_f32(q4_dmin_0123, q8_d_0123, 3);
sbd_min_4567[3] = vmulq_laneq_f32(q4_dmin_4567, q8_d_0123, 3);
// Precompute bsums: each vpaddq collapses the interleaved 16-quant sums into the 32-quant per-subblock sums used for the bias correction
const int16x8_t bsums[q8_k_blocklen] = {
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 0), vld1q_s16(q8_ptr[b].bsums + 16 * 0 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 1), vld1q_s16(q8_ptr[b].bsums + 16 * 1 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 2), vld1q_s16(q8_ptr[b].bsums + 16 * 2 + 8)),
vpaddq_s16(vld1q_s16(q8_ptr[b].bsums + 16 * 3), vld1q_s16(q8_ptr[b].bsums + 16 * 3 + 8)),
};
int16_t bsums_arr[QK_K / 64][8];
for (int q8_row = 0; q8_row < 4; q8_row++) {
vst1q_s16(bsums_arr[q8_row], bsums[q8_row]);
}
// interleaved bias_acc: [0]->r0 c0123, [1]->r0 c4567, [2]->r1 c0123, [3]->r1 c4567, ..
int32x4_t bias_acc[acc_size];
for (int i = 0; i < acc_size; i++) {
bias_acc[i] = vdupq_n_s32(0);
}
for (int sb = 0; sb < QK_K / 64; sb++) {
// Int accumulators for qs vecdot (4 row x 2 col quartets)
int32x4_t acc_lo[acc_size];
int32x4_t acc_hi[acc_size];
for (int i = 0; i < acc_size; i++) {
acc_lo[i] = vdupq_n_s32(0);
acc_hi[i] = vdupq_n_s32(0);
}
// Need scales and mins for the low-nibble and high-nibble 32-quant subblocks
// 12 bytes each across the 8 interleaved columns: 2 * 12 = 24 bytes per 64-quant chunk, 4 chunks -> 96 bytes total
int16x8_t q4sb_scales[2];
int16x8_t q4sb_mins[2];
for (int i = 0; i < 2; i++) {
int8_t aux_q4sb[8];
const int offset = sb * 24 + i * 12;
decode_q4_Kx8_scales_mins(&q4_ptr[b].scales[offset], &q4sb_mins[i], aux_q4sb);
q4sb_scales[i] = vmovl_s8(vld1_s8(aux_q4sb));
}
constexpr int reads_per_sb = 8; // 8 * 16 bytes each => 32 qs * 4 rows
for (int k = 0; k < reads_per_sb; k++) {
const int8x16_t q8_blk0 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k);
const int8x16_t q8_blk1 = vld1q_s8(q8_ptr[b].qs + sb * 256 + 16 * k + 128);
// 0..3 & 32..35
const uint8x16_t q4_0123 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k);
const uint8x16_t q4_4567 = vld1q_u8(q4_ptr[b].qs + sb * QK_K + 32 * k + 16);
const int8x16_t q4_0123_lo = vreinterpretq_s8_u8(vandq_u8(q4_0123, m4b));
const int8x16_t q4_0123_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_0123, 4));
acc_lo[0] = vdotq_laneq_s32(acc_lo[0], q4_0123_lo, q8_blk0, 0); // 0..3 r0 c0123
acc_lo[1] = vdotq_laneq_s32(acc_lo[1], q4_0123_lo, q8_blk0, 1); // 0..3 r1 c0123
acc_lo[2] = vdotq_laneq_s32(acc_lo[2], q4_0123_lo, q8_blk0, 2); // 0..3 r2 c0123
acc_lo[3] = vdotq_laneq_s32(acc_lo[3], q4_0123_lo, q8_blk0, 3); // 0..3 r3 c0123
acc_hi[0] = vdotq_laneq_s32(acc_hi[0], q4_0123_hi, q8_blk1, 0); // 32..35 r0 c0123
acc_hi[1] = vdotq_laneq_s32(acc_hi[1], q4_0123_hi, q8_blk1, 1); // 32..35 r1 c0123
acc_hi[2] = vdotq_laneq_s32(acc_hi[2], q4_0123_hi, q8_blk1, 2); // 32..35 r2 c0123
acc_hi[3] = vdotq_laneq_s32(acc_hi[3], q4_0123_hi, q8_blk1, 3); // 32..35 r3 c0123
const int8x16_t q4_4567_lo = vreinterpretq_s8_u8(vandq_u8(q4_4567, m4b));
const int8x16_t q4_4567_hi = vreinterpretq_s8_u8(vshrq_n_u8(q4_4567, 4));
acc_lo[4] = vdotq_laneq_s32(acc_lo[4], q4_4567_lo, q8_blk0, 0); // 0..3 r0 c4567
acc_lo[5] = vdotq_laneq_s32(acc_lo[5], q4_4567_lo, q8_blk0, 1); // 0..3 r1 c4567
acc_lo[6] = vdotq_laneq_s32(acc_lo[6], q4_4567_lo, q8_blk0, 2); // 0..3 r2 c4567
acc_lo[7] = vdotq_laneq_s32(acc_lo[7], q4_4567_lo, q8_blk0, 3); // 0..3 r3 c4567
acc_hi[4] = vdotq_laneq_s32(acc_hi[4], q4_4567_hi, q8_blk1, 0); // 32..35 r0 c4567
acc_hi[5] = vdotq_laneq_s32(acc_hi[5], q4_4567_hi, q8_blk1, 1); // 32..35 r1 c4567
acc_hi[6] = vdotq_laneq_s32(acc_hi[6], q4_4567_hi, q8_blk1, 2); // 32..35 r2 c4567
acc_hi[7] = vdotq_laneq_s32(acc_hi[7], q4_4567_hi, q8_blk1, 3); // 32..35 r3 c4567
}
// Scale and bias application
// acc is stored interleaved to match output layout
const int16x4_t sc_0123_lo = vget_low_s16(q4sb_scales[0]);
const int16x4_t sc_4567_lo = vget_high_s16(q4sb_scales[0]);
const int16x4_t sc_0123_hi = vget_low_s16(q4sb_scales[1]);
const int16x4_t sc_4567_hi = vget_high_s16(q4sb_scales[1]);
for (int row = 0; row < q8_k_blocklen; row++) {
// Scales
// row c0123 blk0 and blk1
const float32x4_t sumf_0123 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_0123_lo), acc_lo[row]),
vmulq_s32(vmovl_s16(sc_0123_hi), acc_hi[row])));
acc_f32[2 * row] = vfmaq_f32(acc_f32[2 * row], sbd_scale_0123[row], sumf_0123);
// row c4567 blk0 and blk1
const float32x4_t sumf_4567 =
vcvtq_f32_s32(vaddq_s32(vmulq_s32(vmovl_s16(sc_4567_lo), acc_lo[row + 4]),
vmulq_s32(vmovl_s16(sc_4567_hi), acc_hi[row + 4])));
acc_f32[2 * row + 1] = vfmaq_f32(acc_f32[2 * row + 1], sbd_scale_4567[row], sumf_4567);
// Bias
const int16x4_t bsums_vec_lo = vdup_n_s16(bsums_arr[sb][row * 2]);
const int16x4_t bsums_vec_hi = vdup_n_s16(bsums_arr[sb][row * 2 + 1]);
// row c0123 blk0 and blk1
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_lo, vget_low_s16(q4sb_mins[0]));
bias_acc[2 * row] = vmlal_s16(bias_acc[2 * row], bsums_vec_hi, vget_low_s16(q4sb_mins[1]));
// row c4567 blk0 and blk1
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_lo, vget_high_s16(q4sb_mins[0]));
bias_acc[2 * row + 1] =
vmlal_s16(bias_acc[2 * row + 1], bsums_vec_hi, vget_high_s16(q4sb_mins[1]));
}
} // for sb
for (int row = 0; row < q8_k_blocklen; row++) {
acc_f32[2 * row] = vmlsq_f32(acc_f32[2 * row], vcvtq_f32_s32(bias_acc[2 * row]), sbd_min_0123[row]);
acc_f32[2 * row + 1] =
vmlsq_f32(acc_f32[2 * row + 1], vcvtq_f32_s32(bias_acc[2 * row + 1]), sbd_min_4567[row]);
}
} // for b
for (int i = 0; i < q8_k_blocklen; i++) {
int row = y * q8_k_blocklen + i;
for (int j = 0; j < 2; j++) {
int col = x * ncols_interleaved + j * 4;
int offset = row * bs + col;
vst1q_f32(s + offset, acc_f32[2 * i + j]);
}
}
} // for x
} // for y
return;
#endif // defined(__aarch64__) && defined(__ARM_NEON) && defined(__ARM_FEATURE_DOTPROD)
ggml_gemm_q4_K_8x4_q8_K_generic(n, s, bs, vx, vy, nr, nc);
}
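As a quick check on the store loop above, the following standalone sketch (made-up tile coordinates, not part of the diff) prints where each of the eight accumulators of one 4x8 output tile lands in the row-major result s with leading dimension bs:
// --- illustrative sketch, not part of the diff ---
#include <cstdio>

int main() {
    const int y = 1, x = 2, bs = 32;                       // hypothetical tile position and row stride
    const int q8_k_blocklen = 4, ncols_interleaved = 8;
    for (int i = 0; i < q8_k_blocklen; i++) {
        const int row = y * q8_k_blocklen + i;
        for (int j = 0; j < 2; j++) {
            const int col    = x * ncols_interleaved + j * 4;
            const int offset = row * bs + col;
            // acc_f32[2*i + j] holds four consecutive floats written to s[offset .. offset+3]
            std::printf("acc_f32[%d] -> s[%d..%d] (row %d, cols %d..%d)\n",
                        2 * i + j, offset, offset + 3, row, col, col + 3);
        }
    }
    return 0;
}
// --- end sketch ---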
void ggml_gemm_q4_K_8x8_q8_K(int n,
float * GGML_RESTRICT s,
size_t bs,

View File

@ -1,20 +1,23 @@
#include "ggml-backend-impl.h"
#if defined(__riscv) && __riscv_xlen == 64
#include <sys/auxv.h>
//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24
#ifndef COMPAT_HWCAP_ISA_V
#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
#endif
#include <asm/hwprobe.h>
#include <asm/unistd.h>
#include <unistd.h>
struct riscv64_features {
bool has_rvv = false;
riscv64_features() {
uint32_t hwcap = getauxval(AT_HWCAP);
struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V);
int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
}
};
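A minimal sketch of how such a detection struct is typically consumed (the scoring helper is hypothetical, not part of this hunk): a single static instance runs the hwprobe/hwcap query once at load time, and hot paths only read the cached flag.
// --- illustrative sketch, not part of the diff ---
// Compiles on riscv64 Linux together with the struct above.
#include <cstdio>

static const riscv64_features rv_features; // detection runs once, at static initialization

static int pick_backend_score() {
    return rv_features.has_rvv ? 1 : 0; // hypothetical: prefer this backend only when RVV is present
}

int main() {
    std::printf("RVV available: %s (score %d)\n", rv_features.has_rvv ? "yes" : "no", pick_backend_score());
    return 0;
}
// --- end sketch ---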

View File

@ -9766,7 +9766,8 @@ static void ggml_compute_forward_solve_tri_f32(const struct ggml_compute_params
}
const float diag = A_batch[i00 * n + i00];
GGML_ASSERT(diag != 0.0f && "Zero diagonal in triangular matrix");
assert(diag != 0.0f && "Zero diagonal in triangular matrix");
X_batch[i00 * k + i01] = (B_batch[i00 * k + i01] - sum) / diag;
}
}

View File

@ -124,6 +124,58 @@ void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GG
}
}
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
const int nb = k / QK_K;
block_q8_Kx4 * GGML_RESTRICT y = (block_q8_Kx4 *) vy;
// scalar
const int blck_size_interleave = 4;
float srcv[4][QK_K];
float iscale[4];
for (int i = 0; i < nb; i++) {
for (int row_iter = 0; row_iter < 4; row_iter++) {
float amax = 0.0f; // absolute max
float max = 0;
for (int j = 0; j < QK_K; j++) {
srcv[row_iter][j] = x[row_iter * k + i * QK_K + j];
// Update the maximum value of the corresponding super block
if (amax < fabsf(srcv[row_iter][j])) {
amax = fabsf(srcv[row_iter][j]);
max = srcv[row_iter][j];
}
}
iscale[row_iter] = amax ? -127.f/max : 0;
y[i].d[row_iter] = amax ? 1/iscale[row_iter] : 0;
}
for (int j = 0; j < QK_K / 4; j++) {
y[i].bsums[j] = 0;
}
// Quant values are interleaved in groups of four bytes, cycling through the four source rows (super blocks)
// Bsums are interleaved in groups of four per row: the first four bsums of row 0, then the first four of row 1,
// and so on (this mapping is verified by the standalone check after this function)
for (int j = 0; j < QK_K * 4; j++) {
int src_offset = (j / (4 * blck_size_interleave)) * blck_size_interleave;
int src_id = (j % (4 * blck_size_interleave)) / blck_size_interleave;
src_offset += (j % blck_size_interleave);
int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);
float x0 = srcv[src_id][src_offset] * iscale[src_id];
y[i].qs[j] = nearest_int(x0);
y[i].bsums[index] += y[i].qs[j];
}
}
}
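The index arithmetic of the scatter above is dense; this standalone self-check (no ggml headers needed) confirms it realizes exactly the layout described in the comments: quants in groups of four bytes per row, bsums in groups of four per row, rows interleaved within each block of 16 slots.
// --- illustrative sketch, not part of the diff ---
#include <cassert>

int main() {
    const int QK_K = 256;
    const int blck_size_interleave = 4;
    for (int j = 0; j < QK_K * 4; j++) {
        const int src_id  = (j % (4 * blck_size_interleave)) / blck_size_interleave;   // source row 0..3
        const int src_off = (j / (4 * blck_size_interleave)) * blck_size_interleave + (j % blck_size_interleave);
        const int bsum_in_row = src_off / 16;                // 16 quants per bsum, 16 bsums per row
        const int expected = 16 * (bsum_in_row / 4) + 4 * src_id + (bsum_in_row % 4);
        const int index = (((j & 15) >> 2) << 2) + ((j >> 8) << 4) + ((j >> 6) & 3);   // formula from above
        assert(index == expected);
    }
    return 0;
}
// --- end sketch ---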
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k) {
assert(QK_K == 256);
assert(k % QK_K == 0);
@ -192,6 +244,12 @@ template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_0>(const float * GGML_RESTR
ggml_quantize_mat_q8_0_4x8(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<4, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
ggml_quantize_mat_q8_K_4x4(x, vy, n_per_row);
}
template <> void ggml_quantize_mat_t<8, GGML_TYPE_Q8_K>(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t nrow, int64_t n_per_row) {
assert(nrow == 4);
UNUSED(nrow);
@ -333,6 +391,77 @@ void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(bs);
UNUSED(nr);
float sumf[8];
float sum_minf[8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
const block_q8_K * a_ptr = (const block_q8_K *) vy;
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int j = 0; j < ncols_interleaved; j++) {
sumf[j] = 0.0;
sum_minf[j] = 0.0;
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 64 + (k % 8) * blocklen + i + 32]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d;
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[j] += mins[j] * (a_ptr[l].bsums[sb * 2] + a_ptr[l].bsums[sb * 2 + 1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d;
}
}
}
for (int j = 0; j < ncols_interleaved; j++) {
s[x * ncols_interleaved + j] = sumf[j] - sum_minf[j];
}
}
}
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@ -727,6 +856,89 @@ void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs,
}
}
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
const int ncols_interleaved = 8;
const int blocklen = 4;
static const uint32_t kmask1 = 0x3f3f3f3f;
static const uint32_t kmask2 = 0x0f0f0f0f;
static const uint32_t kmask3 = 0x03030303;
assert (n % qk == 0);
assert (nr % 4 == 0);
assert (nc % ncols_interleaved == 0);
UNUSED(nb);
UNUSED(ncols_interleaved);
UNUSED(blocklen);
float sumf[4][8];
float sum_minf[4][8];
uint32_t utmp[32];
int sumi1;
int sumi2;
int sumi;
for (int y = 0; y < nr / 4; y++) {
const block_q8_Kx4 * a_ptr = (const block_q8_Kx4 *) vy + (y * nb);
for (int x = 0; x < nc / ncols_interleaved; x++) {
const block_q4_Kx8 * b_ptr = (const block_q4_Kx8 *) vx + (x * nb);
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumf[m][j] = 0.0;
sum_minf[m][j] = 0.0;
}
}
for (int l = 0; l < nb; l++) {
for (int sb = 0; sb < 8; sb++) {
memcpy(utmp + sb * 4, b_ptr[l].scales + sb * 12, 12);
utmp[sb * 4 + 3] = ((utmp[sb * 4 + 2] >> 4) & kmask2) | (((utmp[sb * 4 + 1] >> 6) & kmask3) << 4);
const uint32_t uaux_0 = utmp[sb * 4 + 1] & kmask1;
utmp[sb * 4 + 1] = (utmp[sb * 4 + 2] & kmask2) | (((utmp[sb * 4 + 0] >> 6) & kmask3) << 4);
utmp[sb * 4 + 2] = uaux_0;
utmp[sb * 4 + 0] &= kmask1;
}
for (int k = 0; k < (qk / (2 * blocklen)); k++) {
uint8_t * scales_0 = (uint8_t *) utmp + (k / 8) * 32;
uint8_t * scales_1 = (uint8_t *) utmp + (k / 8) * 32 + 16;
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
sumi1 = 0;
sumi2 = 0;
sumi = 0;
for (int i = 0; i < blocklen; ++i) {
const int v0 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] & 0xF);
const int v1 = (int8_t) (b_ptr[l].qs[k * ncols_interleaved * blocklen + j * blocklen + i] >> 4);
sumi1 = (v0 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i]);
sumi2 = (v1 * a_ptr[l].qs[(k / 8) * 256 + (k % 8) * 4 * blocklen + m * blocklen + i + 128]);
sumi1 = sumi1 * scales_0[j];
sumi2 = sumi2 * scales_1[j];
sumi += sumi1 + sumi2;
}
sumf[m][j] += sumi * GGML_CPU_FP16_TO_FP32(b_ptr[l].d[j]) * a_ptr[l].d[m];
}
}
}
for (int sb = 0; sb < 8; sb++) {
uint8_t * mins = (uint8_t *) utmp + 8 + sb * 16;
for (int m = 0; m < 4; m++) {
const int16_t * bsums = a_ptr[l].bsums + (sb * 8) + (m * 4) - ((sb % 2) * 6);
for (int j = 0; j < ncols_interleaved; j++) {
sum_minf[m][j] += mins[j] * (bsums[0] + bsums[1]) * GGML_CPU_FP16_TO_FP32(b_ptr[l].dmin[j]) * a_ptr[l].d[m];
}
}
}
}
for (int m = 0; m < 4; m++) {
for (int j = 0; j < ncols_interleaved; j++) {
s[(y * 4 + m) * bs + x * ncols_interleaved + j] = sumf[m][j] - sum_minf[m][j];
}
}
}
}
}
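The expression (sb * 8) + (m * 4) - ((sb % 2) * 6) above undoes the 4x4 bsums interleaving produced by ggml_quantize_mat_q8_K_4x4_generic; the standalone check below (a sketch, no ggml dependencies) verifies that for every subblock sb and row m it lands on the two adjacent 16-quant sums making up that 32-quant subblock.
// --- illustrative sketch, not part of the diff ---
#include <cassert>

// Slot of the bsum for (row m, 16-quant group g) in the interleaved Q8_Kx4 layout:
// groups of four bsums per row, rows interleaved within each block of 16 slots.
static int interleaved_pos(int m, int g) {
    return 16 * (g / 4) + 4 * m + (g % 4);
}

int main() {
    for (int sb = 0; sb < 8; sb++) {       // 8 subblocks of 32 quants each
        for (int m = 0; m < 4; m++) {      // 4 interleaved rows
            const int base = (sb * 8) + (m * 4) - ((sb % 2) * 6);
            // subblock sb covers 16-quant groups 2*sb and 2*sb+1 of row m
            assert(base     == interleaved_pos(m, 2 * sb));
            assert(base + 1 == interleaved_pos(m, 2 * sb + 1));
        }
    }
    return 0;
}
// --- end sketch ---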
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc) {
const int qk = QK_K;
const int nb = n / qk;
@ -1228,9 +1440,10 @@ static int repack_q4_0_to_q4_0_4_bl(struct ggml_tensor * t, int interleave_block
GGML_UNUSED(data_size);
}
static int repack_q4_K_to_q4_K_8_bl(struct ggml_tensor * t, int interleave_block, const void * GGML_RESTRICT data, size_t data_size) {
GGML_ASSERT(t->type == GGML_TYPE_Q4_K);
GGML_ASSERT(interleave_block == 8);
GGML_ASSERT(interleave_block == 8 || interleave_block == 4);
constexpr int nrows_interleaved = 8;
block_q4_Kx8 * dst = (block_q4_Kx8*)t->data;
@ -1468,6 +1681,10 @@ template <> int repack<block_q4_K, 8, 8>(struct ggml_tensor * t, const void * da
return repack_q4_K_to_q4_K_8_bl(t, 8, data, data_size);
}
template <> int repack<block_q4_K, 4, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q4_K_to_q4_K_8_bl(t, 4, data, data_size);
}
template <> int repack<block_q2_K, 8, 8>(struct ggml_tensor * t, const void * data, size_t data_size) {
return repack_q2_K_to_q2_K_8_bl(t, 8, data, data_size);
}
@ -1501,6 +1718,10 @@ template <> void gemv<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemv_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemv<block_q4_K, 8, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemv_q4_K_8x8_q8_K(n, s, bs, vx, vy, nr, nc);
}
@ -1529,6 +1750,10 @@ template <> void gemm<block_q4_0, 8, 4, GGML_TYPE_Q8_0>(int n, float * s, size_t
ggml_gemm_q4_0_4x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_K, 4, 8, GGML_TYPE_Q8_K>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_K_8x4_q8_K(n, s, bs, vx, vy, nr, nc);
}
template <> void gemm<block_q4_0, 8, 8, GGML_TYPE_Q8_0>(int n, float * s, size_t bs, const void * vx, const void * vy, int nr, int nc) {
ggml_gemm_q4_0_8x8_q8_0(n, s, bs, vx, vy, nr, nc);
}
@ -1731,12 +1956,13 @@ template <typename BLOC_TYPE, int64_t INTER_SIZE, int64_t NB_COLS, ggml_type PAR
nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;
}
if (nth == 1 || nchunk0 < nth || disable_chunking) {
int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Only increase nchunk0 to nth if it won't make chunks too small
if (nth == 1 || ((nchunk0 < nth || disable_chunking) && (nr0 + nth - 1) / nth >= min_chunk_size)) {
nchunk0 = nth;
dr0 = (nr0 + nchunk0 - 1) / nchunk0;
}
const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;
// Ensure nchunk doesn't exceed the number of rows divided by minimum chunk size
// This prevents creating too many tiny chunks that could overlap after alignment
const int64_t max_nchunk = (nr0 + min_chunk_size - 1) / min_chunk_size;
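To make the guard above concrete, a toy calculation (made-up numbers, simplified by dropping the disable_chunking flag) shows why nchunk0 is not blindly raised to the thread count:
// --- illustrative sketch, not part of the diff ---
#include <cstdint>
#include <cstdio>

int main() {
    const int64_t nr0 = 96, min_chunk_size = 16, nth = 16;          // hypothetical sizes
    int64_t nchunk0 = (nr0 + min_chunk_size - 1) / min_chunk_size;  // 6 chunks of at least 16 rows
    if (nth == 1 || (nchunk0 < nth && (nr0 + nth - 1) / nth >= min_chunk_size)) {
        nchunk0 = nth;  // not taken here: 96 rows / 16 threads = 6 rows per chunk, below the minimum
    }
    const int64_t dr0 = (nr0 + nchunk0 - 1) / nchunk0;              // 16 rows per chunk
    std::printf("nchunk0 = %lld, dr0 = %lld\n", (long long) nchunk0, (long long) dr0);
    return 0;
}
// --- end sketch ---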
@ -1930,6 +2156,9 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
static const ggml::cpu::repack::tensor_traits<block_q4_0, 4, 4, GGML_TYPE_Q8_0> q4_0_4x4_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 4, GGML_TYPE_Q8_0> q4_0_4x8_q8_0;
static const ggml::cpu::repack::tensor_traits<block_q4_0, 8, 8, GGML_TYPE_Q8_0> q4_0_8x8_q8_0;
// instance for Q4_K
static const ggml::cpu::repack::tensor_traits<block_q4_K, 4, 8, GGML_TYPE_Q8_K> q4_K_8x4_q8_K;
static const ggml::cpu::repack::tensor_traits<block_q4_K, 8, 8, GGML_TYPE_Q8_K> q4_K_8x8_q8_K;
// instance for Q2
@ -1966,6 +2195,11 @@ static const ggml::cpu::tensor_traits * ggml_repack_get_optimal_repack_type(cons
return &q4_K_8x8_q8_K;
}
}
if (ggml_cpu_has_neon() && ggml_cpu_has_dotprod()) {
if (cur->ne[1] % 8 == 0) {
return &q4_K_8x4_q8_K;
}
}
} else if (cur->type == GGML_TYPE_Q2_K) {
if (ggml_cpu_has_avx512()) {
if (cur->ne[1] % 8 == 0) {

View File

@ -80,10 +80,12 @@ extern "C" {
void ggml_quantize_mat_q8_0_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -91,6 +93,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
void ggml_gemm_q4_0_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -99,10 +102,12 @@ void ggml_gemm_iq4_nl_8x8_q8_0(int n, float * GGML_RESTRICT s, size_t bs, const
// Native implementations
void ggml_quantize_mat_q8_0_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_0_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x4_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_quantize_mat_q8_K_4x8_generic(const float * GGML_RESTRICT x, void * GGML_RESTRICT vy, int64_t k);
void ggml_gemv_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemv_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
@ -110,6 +115,7 @@ void ggml_gemv_iq4_nl_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs
void ggml_gemm_q4_0_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_4x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_0_8x8_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x4_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q4_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_q2_K_8x8_q8_K_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);
void ggml_gemm_iq4_nl_4x4_q8_0_generic(int n, float * GGML_RESTRICT s, size_t bs, const void * GGML_RESTRICT vx, const void * GGML_RESTRICT vy, int nr, int nc);

View File

@ -84,12 +84,12 @@
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD
#define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB
@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE
@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL;
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fast_fp16_available(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610);
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
}
// To be used for feature selection of external libraries, e.g. cuBLAS.
@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
}
static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3;
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
}
static bool fp32_mma_hardware_available(const int cc) {
@ -558,8 +562,12 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const float2 v
acc += v.y*u.y;
}
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#if defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
#define V_DOT2_F32_F16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(__gfx906__) || defined(CDNA))
static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v, const half2 u) {
#ifdef V_DOT2_F32_F16_AVAILABLE
asm volatile("v_dot2_f32_f16 %0, %1, %2, %0" : "+v"(acc) : "v"(v), "v"(u));
#else
#ifdef FAST_FP16_AVAILABLE
@ -571,7 +579,7 @@ static __device__ __forceinline__ void ggml_cuda_mad(float & acc, const half2 v,
acc += tmpv.x * tmpu.x;
acc += tmpv.y * tmpu.y;
#endif // FAST_FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) && (defined(RDNA2) || defined(RDNA3) || defined(RDNA4) || defined(GCN5) || defined(CDNA))
#endif // V_DOT2_F32_F16_AVAILABLE
}
static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, const half2 u) {

View File

@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
}
}
}
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
nb12, nb13);
}
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
ne00n = ne00;
ne01n = ne01;
ne02n = ne02;
} else if (nb00 > nb02) {
} else {
ne00n = ne00;
ne01n = ne01*ne02;
ne02n = 1;

View File

@ -55,11 +55,11 @@ static __device__ __forceinline__ float vec_dot_fattn_vec_KQ_f16(
ggml_cuda_memcpy_1<sizeof(tmp)>(tmp, K_h2 + k_KQ_0 + (threadIdx.x % nthreads)*cpy_ne);
#pragma unroll
for (int k_KQ_1 = 0; k_KQ_1 < cpy_ne; ++k_KQ_1) {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
ggml_cuda_mad(sum, tmp[k_KQ_1] , ((const half2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#else
ggml_cuda_mad(sum, __half22float2(tmp[k_KQ_1]), ((const float2 *) Q_v)[k_KQ_0/nthreads + k_KQ_1]);
#endif // FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}

View File

@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f;
#pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val;

View File

@ -86,11 +86,11 @@ static __global__ void flash_attn_ext_vec(
constexpr vec_dot_KQ_t vec_dot_KQ = get_vec_dot_KQ<type_K, D, nthreads_KQ>();
constexpr bool Q_q8_1 = type_K != GGML_TYPE_F16;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, half, V_rows_per_thread>();
#else
constexpr dequantize_V_t dequantize_V = get_dequantize_V<type_V, float, V_rows_per_thread>();
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on.
@ -112,13 +112,13 @@ static __global__ void flash_attn_ext_vec(
constexpr int ne_KQ = ncols*D;
constexpr int ne_combine = nwarps*V_cols_per_iter*D;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ half KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#else
float2 VKQ[ncols][(D/2)/nthreads_V] = {{{0.0f, 0.0f}}};
__shared__ float KQ[ne_KQ > ne_combine ? ne_KQ : ne_combine];
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
float KQ_max[ncols];
float KQ_sum[ncols];
@ -129,11 +129,11 @@ static __global__ void flash_attn_ext_vec(
}
// Convert Q to float2 (f16 K) or q8_1 (quantized K) and store in registers:
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 Q_reg[ncols][(D/2)/nthreads_KQ]; // Will be initialized completely.
#else
float2 Q_reg[ncols][(D/2)/nthreads_KQ] = {{{0.0f, 0.0f}}}; // May be only partially initialized.
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
int Q_i32[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
float2 Q_ds[ncols][1 > D/(sizeof(int)*nthreads_KQ) ? 1 : D/(sizeof(int)*nthreads_KQ)];
if constexpr (Q_q8_1) {
@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) {
if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0;
}
}
@ -191,7 +191,7 @@ static __global__ void flash_attn_ext_vec(
__syncthreads();
} else {
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 scale_h2 = make_half2(scale, scale);
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@ -233,7 +233,7 @@ static __global__ void flash_attn_ext_vec(
Q_reg[j][k].y *= scale;
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11;
@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) {
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum;
}
}
@ -291,7 +291,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + KQ_reg[j];
KQ[j*nthreads + tid] = KQ_reg[j];
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@ -303,7 +303,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
#ifndef GGML_USE_HIP
@ -314,7 +314,7 @@ static __global__ void flash_attn_ext_vec(
for (int k0 = 0; k0 < WARP_SIZE; k0 += V_cols_per_iter) {
const int k = threadIdx.y*WARP_SIZE + k0 + (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 KQ_k[ncols];
#pragma unroll
for (int j = 0; j < ncols; ++j) {
@ -353,7 +353,7 @@ static __global__ void flash_attn_ext_vec(
}
}
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@ -374,7 +374,7 @@ static __global__ void flash_attn_ext_vec(
KQ_sum[j] = KQ_sum[j]*KQ_max_scale + (threadIdx.x == 0 ? expf(sink - KQ_max[j]) : 0.0f);
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale);
#pragma unroll
for (int i_VKQ_0 = 0; i_VKQ_0 < D/2; i_VKQ_0 += nthreads_V) {
@ -386,7 +386,7 @@ static __global__ void flash_attn_ext_vec(
VKQ[j][i_VKQ_0/nthreads_V].x *= KQ_max_scale;
VKQ[j][i_VKQ_0/nthreads_V].y *= KQ_max_scale;
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
}
}
@ -421,7 +421,7 @@ static __global__ void flash_attn_ext_vec(
const float kqmax_scale = expf(KQ_max[j_VKQ] - kqmax_new);
KQ_max[j_VKQ] = kqmax_new;
#ifdef FAST_FP16_AVAILABLE
#ifdef V_DOT2_F32_F16_AVAILABLE
half2 * VKQ_tmp = (half2 *) KQ + threadIdx.y*(V_cols_per_iter*D/2)
+ (nthreads_V == WARP_SIZE ? 0 : threadIdx.x / nthreads_V)*(D/2);
@ -452,7 +452,7 @@ static __global__ void flash_attn_ext_vec(
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ, &VKQ[j_VKQ][i_VKQ_0/nthreads_V]);
ggml_cuda_memcpy_1<V_rows_per_thread/2*sizeof(float)>(VKQ_tmp + i_VKQ + V_rows_per_thread/4, &VKQ[j_VKQ][i_VKQ_0/nthreads_V + V_rows_per_thread/4]);
}
#endif // FAST_FP16_AVAILABLE
#endif // V_DOT2_F32_F16_AVAILABLE
KQ_sum[j_VKQ] *= kqmax_scale;
KQ_sum[j_VKQ] = warp_reduce_sum(KQ_sum[j_VKQ]);

View File

@ -55,6 +55,7 @@
#include "ggml-cuda/set.cuh"
#include "ggml-cuda/set-rows.cuh"
#include "ggml-cuda/pad_reflect_1d.cuh"
#include "ggml-cuda/solve_tri.cuh"
#include "ggml.h"
#include <algorithm>
@ -2725,6 +2726,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_OPT_STEP_SGD:
ggml_cuda_opt_step_sgd(ctx, dst);
break;
case GGML_OP_SOLVE_TRI:
ggml_cuda_op_solve_tri(ctx, dst);
break;
default:
return false;
}
@ -3054,7 +3058,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() &&
const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@ -3064,8 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops.size() &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@ -3073,7 +3081,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == topk_moe_ops_delayed_softmax.size() &&
if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@ -3089,9 +3097,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) {
if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
@ -3103,9 +3110,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) ||
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) {
if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
@ -3115,7 +3121,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
}
}
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * rope = cgraph->nodes[node_idx];
const ggml_tensor * view = cgraph->nodes[node_idx + 1];
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];
@ -3845,7 +3853,7 @@ static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t *
// Check if UMA is explicitly enabled via environment variable
bool uma_env = getenv("GGML_CUDA_ENABLE_UNIFIED_MEMORY") != nullptr;
bool is_uma = prop.unifiedAddressing > 0 || uma_env;
bool is_uma = prop.integrated > 0 || uma_env;
if (is_uma) {
// For UMA systems (like DGX Spark), use system memory info
@ -4265,6 +4273,8 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return true;
case GGML_OP_SOLVE_TRI:
return op->src[0]->ne[0] <= 64 && op->src[1]->ne[0] <= 32;
default:
return false;
}

View File

@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D;
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A;
tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

View File

@ -151,7 +151,7 @@ bool ggml_cuda_should_use_mmf(enum ggml_type type, int cc, int warp_size, const
return false;
}
} else {
if (src1_ncols > 16 || GGML_CUDA_CC_IS_RDNA4(cc)) {
if (src1_ncols > 16) {
return false;
}
}

View File

@ -0,0 +1,203 @@
#include "common.cuh"
#include "ggml.h"
#include "solve_tri.cuh"
#define MAX_N_FAST 64
#define MAX_K_FAST 32
// ======================
// Fast Kernel (n <= 64, k <= 32) - Warp-based parallel reduction
// ======================
// When the n/k template arguments are 0 the bounds for the loops in this function are not
// known at compile time and can't be unrolled. As we want to keep pragma unroll for all other
// cases we suppress the clang transformation warning here.
#ifdef __clang__
# pragma clang diagnostic push
# pragma clang diagnostic ignored "-Wpass-failed"
#endif // __clang__
template <int n_template, int k_template>
static __global__ void solve_tri_f32_fast(const float * __restrict__ A,
const float * __restrict__ B,
float * __restrict__ X,
const uint3 ne02,
const size_t nb02,
const size_t nb03,
const size_t nb12,
const size_t nb13,
const size_t nb2,
const size_t nb3,
const int n_arg,
const int k_arg) {
const int n = n_template == 0 ? n_arg : n_template;
const int k = k_template == 0 ? k_arg : k_template;
const int batch_idx = blockIdx.x;
const int lane = threadIdx.x;
const int col_idx = threadIdx.y;
if (col_idx >= k) {
return;
}
const uint2 i02_i03 = fast_div_modulo(batch_idx, ne02);
const int64_t i02 = i02_i03.y;
const int64_t i03 = i02_i03.x;
const float * const A_batch = (const float *) (A + i02 * nb02 + i03 * nb03);
const float * const B_batch = (const float *) (B + i02 * nb12 + i03 * nb13);
float * X_batch = (float *) (X + i02 * nb2 + i03 * nb3);
__shared__ float sA[MAX_N_FAST * MAX_N_FAST];
__shared__ float sXt[MAX_N_FAST * (MAX_K_FAST + 1)];
const int offset = threadIdx.x + threadIdx.y * blockDim.x;
#pragma unroll
for (int i = 0; i < n * n; i += k * WARP_SIZE) {
int i0 = i + offset;
if (i0 < n * n) {
sA[i0] = A_batch[i0];
}
}
const int rows_per_warp = (n + WARP_SIZE - 1) / WARP_SIZE;
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
sXt[col_idx * n + i0] = B_batch[i0 * k + col_idx];
}
}
__syncthreads();
#pragma unroll
for (int row = 0; row < n; ++row) {
float sum = 0.0f;
{
int j = lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
if (row >= WARP_SIZE) {
int j = WARP_SIZE + lane;
if (j < row) {
sum += sA[row * n + j] * sXt[col_idx * n + j];
}
}
sum = warp_reduce_sum(sum);
if (lane == 0) {
const float b_val = sXt[col_idx * n + row];
const float a_diag = sA[row * n + row];
// no safeguards for division by zero because that indicates corrupt
// data anyway
sXt[col_idx * n + row] = (b_val - sum) / a_diag;
}
}
__syncthreads();
#pragma unroll
for (int i = 0; i < rows_per_warp; i++) {
const int i0 = lane + i * WARP_SIZE;
if (i0 < n) {
X_batch[i0 * k + col_idx] = sXt[col_idx * n + i0];
}
}
}
#ifdef __clang__
# pragma clang diagnostic pop
#endif // __clang__
static void solve_tri_f32_cuda(const float * A,
const float * B,
float * X,
int n,
int k,
int64_t ne02,
int64_t ne03,
size_t nb02,
size_t nb03,
size_t nb12,
size_t nb13,
size_t nb2,
size_t nb3,
cudaStream_t stream) {
const uint3 ne02_fd = init_fastdiv_values((uint32_t) ne02);
dim3 threads(WARP_SIZE, k);
dim3 grid(ne02 * ne03);
if (n == 64) {
switch (k) {
case 32:
solve_tri_f32_fast<64, 32>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 16:
solve_tri_f32_fast<64, 16>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 14:
solve_tri_f32_fast<64, 14>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 12:
solve_tri_f32_fast<64, 12>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 10:
solve_tri_f32_fast<64, 10>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 8:
solve_tri_f32_fast<64, 8>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 6:
solve_tri_f32_fast<64, 6>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 4:
solve_tri_f32_fast<64, 4>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 2:
solve_tri_f32_fast<64, 2>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
case 1:
solve_tri_f32_fast<64, 1>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, 0, 0);
break;
default:
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
} else { // run general case
solve_tri_f32_fast<0, 0>
<<<grid, threads, 0, stream>>>(A, B, X, ne02_fd, nb02, nb03, nb12, nb13, nb2, nb3, n, k);
}
}
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0]; // A (triangular n x n matrix)
const ggml_tensor * src1 = dst->src[1]; // B (right-hand side of the n x k system, k columns)
GGML_ASSERT(ggml_is_contiguous(src0));
GGML_ASSERT(ggml_is_contiguous(src1));
const int64_t n = src0->ne[0];
const int64_t k = src1->ne[0];
GGML_ASSERT(n <= 64);
GGML_ASSERT(k <= 32);
solve_tri_f32_cuda((const float *) src0->data, (const float *) src1->data, (float *) dst->data, n, k, src0->ne[2],
src0->ne[3], src0->nb[2] / sizeof(float), src0->nb[3] / sizeof(float),
src1->nb[2] / sizeof(float), src1->nb[3] / sizeof(float), dst->nb[2] / sizeof(float),
dst->nb[3] / sizeof(float), ctx.stream());
}
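
For reference, the kernel launched above assigns one 32-lane warp per right-hand-side column (threads are laid out as WARP_SIZE x k) and one block per (i02, i03) batch element, and performs standard forward substitution on the triangular system A * X = B, one row at a time:

x_{r,c} = \frac{b_{r,c} - \sum_{j=0}^{r-1} a_{r,j}\, x_{j,c}}{a_{r,r}}, \qquad r = 0, \dots, n-1

The two lane-strided partial sums cover all j < row for n up to 64 with a 32-wide warp, and the warp reduction produces the dot product before lane 0 applies the division by the diagonal.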

View File

@ -0,0 +1,3 @@
#include "common.cuh"
void ggml_cuda_op_solve_tri(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

View File

@ -70,6 +70,7 @@ set(GGML_OPENCL_KERNELS
group_norm
im2col_f32
im2col_f16
mean
mul_mat_Ab_Bi_8x4
mul_mv_f16_f16
mul_mv_f16_f32_1row
@ -109,6 +110,9 @@ set(GGML_OPENCL_KERNELS
softmax_4_f16
softmax_f32
softmax_f16
sqr
sqrt
ssm_conv
sub
sum_rows
transpose

View File

@ -449,6 +449,9 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_sub, kernel_sub_row, kernel_sub_f16, kernel_sub_row_f16;
cl_kernel kernel_add_id;
cl_kernel kernel_scale;
cl_kernel kernel_sqr_cont_f32, kernel_sqr_cont_f32_4, kernel_sqr_cont_f16, kernel_sqr_cont_f16_4;
cl_kernel kernel_sqrt_cont_f32, kernel_sqrt_cont_f32_4, kernel_sqrt_cont_f16, kernel_sqrt_cont_f16_4;
cl_kernel kernel_mean_f32;
cl_kernel kernel_silu, kernel_silu_4;
cl_kernel kernel_gelu, kernel_gelu_4;
cl_kernel kernel_gelu_erf, kernel_gelu_erf_4;
@ -509,6 +512,7 @@ struct ggml_backend_opencl_context {
cl_kernel kernel_conv_2d_f16;
cl_kernel kernel_conv_2d_f32;
cl_kernel kernel_conv_2d_f16_f32;
cl_kernel kernel_ssm_conv_f32_f32, kernel_ssm_conv_f32_f32_4;
cl_kernel kernel_timestep_embedding;
cl_kernel kernel_gemv_moe_mxfp4_f32, kernel_gemm_moe_mxfp4_f32;
cl_kernel kernel_mul_mv_id_q4_0_f32_8x_flat;
@ -1552,6 +1556,66 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
GGML_LOG_CONT(".");
}
// sqr
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "sqr.cl.h"
};
#else
const std::string kernel_src = read_file("sqr.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sqr_cont_f32 = clCreateKernel(prog, "kernel_sqr_cont_f32", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f32_4 = clCreateKernel(prog, "kernel_sqr_cont_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f16 = clCreateKernel(prog, "kernel_sqr_cont_f16", &err), err));
CL_CHECK((backend_ctx->kernel_sqr_cont_f16_4 = clCreateKernel(prog, "kernel_sqr_cont_f16_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// sqrt
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "sqrt.cl.h"
};
#else
const std::string kernel_src = read_file("sqrt.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_sqrt_cont_f32 = clCreateKernel(prog, "kernel_sqrt_cont_f32", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f32_4 = clCreateKernel(prog, "kernel_sqrt_cont_f32_4", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f16 = clCreateKernel(prog, "kernel_sqrt_cont_f16", &err), err));
CL_CHECK((backend_ctx->kernel_sqrt_cont_f16_4 = clCreateKernel(prog, "kernel_sqrt_cont_f16_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mean
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "mean.cl.h"
};
#else
const std::string kernel_src = read_file("mean.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_mean_f32 = clCreateKernel(prog, "kernel_mean_f32", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// sub
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -1825,6 +1889,24 @@ static void load_cl_kernels(ggml_backend_opencl_context *backend_ctx, ggml_cl_ve
}
}
// ssm_conv
{
#ifdef GGML_OPENCL_EMBED_KERNELS
const std::string kernel_src {
#include "ssm_conv.cl.h"
};
#else
const std::string kernel_src = read_file("ssm_conv.cl");
#endif
cl_program prog =
build_program_from_source(backend_ctx->context, backend_ctx->device, kernel_src.c_str(), compile_opts);
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32", &err), err));
CL_CHECK((backend_ctx->kernel_ssm_conv_f32_f32_4 = clCreateKernel(prog, "kernel_ssm_conv_f32_f32_4", &err), err));
CL_CHECK(clReleaseProgram(prog));
GGML_LOG_CONT(".");
}
// mul_mv_id_q4_0_f32_8x_flat
{
#ifdef GGML_OPENCL_EMBED_KERNELS
@ -2959,6 +3041,10 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
(op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16);
case GGML_OP_ADD_ID:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
ggml_is_contiguous(op->src[0]);
case GGML_OP_UNARY:
switch (ggml_get_unary_op(op)) {
case GGML_UNARY_OP_GELU:
@ -3007,6 +3093,8 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16 && op->type == GGML_TYPE_F16) ||
(op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32) ||
(op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_SSM_CONV:
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32);
case GGML_OP_CONCAT:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->type == GGML_TYPE_F32;
case GGML_OP_TIMESTEP_EMBEDDING:
@ -3075,6 +3163,7 @@ static bool ggml_opencl_supports_op(ggml_backend_dev_t dev, const struct ggml_te
return cols <= max_workgroup_size && op->src[0]->type == GGML_TYPE_F32;
}
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous(op->src[0]);
case GGML_OP_FLASH_ATTN_EXT:
{
@ -5193,6 +5282,224 @@ static void ggml_cl_sub(ggml_backend_t backend, const ggml_tensor * src0, const
}
}
static void ggml_cl_sqr(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
// Currently assumes src0 is contiguous
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqr_cont_f32_4;
} else {
kernel = backend_ctx->kernel_sqr_cont_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqr_cont_f32;
} else {
kernel = backend_ctx->kernel_sqr_cont_f16;
}
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_sqrt(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
UNUSED(src1);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
cl_kernel kernel;
// Currently assumes src0 is contiguous
int n = ggml_nelements(dst);
if (n % 4 == 0) {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqrt_cont_f32_4;
} else {
kernel = backend_ctx->kernel_sqrt_cont_f16_4;
}
n /= 4;
} else {
if (src0->type == GGML_TYPE_F32) {
kernel = backend_ctx->kernel_sqrt_cont_f32;
} else {
kernel = backend_ctx->kernel_sqrt_cont_f16;
}
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
size_t global_work_size[] = {(size_t)n, 1, 1};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (n % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
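
Both wrappers above share the same dispatch rule for contiguous tensors: use the *_4 kernel when the element count divides evenly by 4 and launch one work-item per (possibly packed) element. A minimal sketch of that rule, with a hypothetical helper name not present in the diff:

#include <cstddef>

// Hypothetical helper mirroring the dispatch in ggml_cl_sqr / ggml_cl_sqrt:
// reports whether the vectorized kernel applies and how many work-items to enqueue.
static size_t unary_cont_work_items(int n_elements, bool & use_vec4) {
    use_vec4 = (n_elements % 4 == 0); // the whole tensor must split into float4/half4 lanes
    return use_vec4 ? (size_t) n_elements / 4 : (size_t) n_elements;
}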
static void ggml_cl_mean(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
GGML_UNUSED(src1);
GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
const int ne00 = src0->ne[0];
const int ne01 = src0->ne[1];
const int ne02 = src0->ne[2];
const int ne03 = src0->ne[3];
const cl_ulong nb01 = src0->nb[1];
const cl_ulong nb02 = src0->nb[2];
const cl_ulong nb03 = src0->nb[3];
const cl_ulong nb1 = dst->nb[1];
const cl_ulong nb2 = dst->nb[2];
const cl_ulong nb3 = dst->nb[3];
cl_kernel kernel = backend_ctx->kernel_mean_f32;
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(int), &ne00));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(int), &ne01));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(int), &ne02));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(int), &ne03));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb03));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb2));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb3));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne02, (size_t)ne03};
size_t local_work_size[] = {(size_t)64, 1, 1};
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size, dst);
}
static void ggml_cl_ssm_conv(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
GGML_ASSERT(src1);
GGML_ASSERT(src1->extra);
GGML_ASSERT(dst);
GGML_ASSERT(dst->extra);
ggml_backend_opencl_context *backend_ctx = (ggml_backend_opencl_context *)backend->context;
ggml_tensor_extra_cl * extra0 = (ggml_tensor_extra_cl *)src0->extra;
ggml_tensor_extra_cl * extra1 = (ggml_tensor_extra_cl *)src1->extra;
ggml_tensor_extra_cl * extrad = (ggml_tensor_extra_cl *)dst->extra;
cl_ulong offset0 = extra0->offset + src0->view_offs;
cl_ulong offset1 = extra1->offset + src1->view_offs;
cl_ulong offsetd = extrad->offset + dst->view_offs;
int ne01 = src0->ne[1];
cl_ulong nb00 = src0->nb[0];
cl_ulong nb01 = src0->nb[1];
cl_ulong nb02 = src0->nb[2];
int ne10 = src1->ne[0];
cl_ulong nb11 = src1->nb[1];
int ne1 = dst->ne[1];
int ne2 = dst->ne[2];
cl_ulong nb0 = dst->nb[0];
cl_ulong nb1 = dst->nb[1];
cl_ulong nb2 = dst->nb[2];
cl_kernel kernel = backend_ctx->kernel_ssm_conv_f32_f32;
if (ne10 % 4 == 0) {
kernel = backend_ctx->kernel_ssm_conv_f32_f32_4;
}
CL_CHECK(clSetKernelArg(kernel, 0, sizeof(cl_mem), &extra0->data_device));
CL_CHECK(clSetKernelArg(kernel, 1, sizeof(cl_ulong), &offset0));
CL_CHECK(clSetKernelArg(kernel, 2, sizeof(cl_mem), &extra1->data_device));
CL_CHECK(clSetKernelArg(kernel, 3, sizeof(cl_ulong), &offset1));
CL_CHECK(clSetKernelArg(kernel, 4, sizeof(cl_mem), &extrad->data_device));
CL_CHECK(clSetKernelArg(kernel, 5, sizeof(cl_ulong), &offsetd));
CL_CHECK(clSetKernelArg(kernel, 6, sizeof(cl_ulong), &nb00));
CL_CHECK(clSetKernelArg(kernel, 7, sizeof(cl_ulong), &nb01));
CL_CHECK(clSetKernelArg(kernel, 8, sizeof(cl_ulong), &nb02));
CL_CHECK(clSetKernelArg(kernel, 9, sizeof(int), &ne10));
CL_CHECK(clSetKernelArg(kernel, 10, sizeof(cl_ulong), &nb11));
CL_CHECK(clSetKernelArg(kernel, 11, sizeof(cl_ulong), &nb0));
CL_CHECK(clSetKernelArg(kernel, 12, sizeof(cl_ulong), &nb1));
CL_CHECK(clSetKernelArg(kernel, 13, sizeof(cl_ulong), &nb2));
size_t global_work_size[] = {(size_t)ne01, (size_t)ne1, (size_t)ne2};
size_t local_work_size[] = {64, 1, 1};
size_t * local_work_size_ptr = local_work_size;
if (ne01 % 64 != 0 && !backend_ctx->non_uniform_workgroups) {
local_work_size_ptr = nullptr;
}
backend_ctx->enqueue_ndrange_kernel(kernel, 3, global_work_size, local_work_size_ptr, dst);
}
static void ggml_cl_gelu(ggml_backend_t backend, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(src0);
GGML_ASSERT(src0->extra);
@ -9091,6 +9398,24 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_sub;
break;
case GGML_OP_SQR:
if (!any_on_device) {
return false;
}
func = ggml_cl_sqr;
break;
case GGML_OP_SQRT:
if (!any_on_device) {
return false;
}
func = ggml_cl_sqrt;
break;
case GGML_OP_MEAN:
if (!any_on_device) {
return false;
}
func = ggml_cl_mean;
break;
case GGML_OP_UNARY:
switch (ggml_get_unary_op(tensor)) {
case GGML_UNARY_OP_GELU:
@ -9192,6 +9517,12 @@ bool ggml_cl_compute_forward(ggml_backend_t backend, struct ggml_tensor * tensor
}
func = ggml_cl_conv_2d;
break;
case GGML_OP_SSM_CONV:
if (!any_on_device) {
return false;
}
func = ggml_cl_ssm_conv;
break;
case GGML_OP_CONCAT:
if (!any_on_device) {
return false;

View File

@ -0,0 +1,39 @@
kernel void kernel_mean_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd,
int ne00,
int ne01,
int ne02,
int ne03,
ulong nb01,
ulong nb02,
ulong nb03,
ulong nb1,
ulong nb2,
ulong nb3
) {
src0 = (global float *)((global char *)src0 + offset0);
dst = (global float *)((global char *)dst + offsetd);
int i3 = get_global_id(2);
int i2 = get_global_id(1);
int i1 = get_global_id(0);
if (i3 >= ne03 || i2 >= ne02 || i1 >= ne01) {
return;
}
global float * src_row = (global float *) ((global char *) src0 + i1*nb01 + i2*nb02 + i3*nb03);
global float * dst_row = (global float *) ((global char *) dst + i1*nb1 + i2*nb2 + i3*nb3);
float row_sum = 0;
for (int i0 = 0; i0 < ne00; i0++) {
row_sum += src_row[i0];
}
dst_row[0] = row_sum / ne00;
}

View File

@ -0,0 +1,53 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_sqr_cont_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f32_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f16(
global half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}
kernel void kernel_sqr_cont_f16_4(
global half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = src0[gid] * src0[gid];
}

View File

@ -0,0 +1,53 @@
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
kernel void kernel_sqrt_cont_f32(
global float * src0,
ulong offset0,
global float * dst,
ulong offsetd
) {
src0 = (global float*)((global char*)src0 + offset0);
dst = (global float*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = sqrt(src0[gid]);
}
kernel void kernel_sqrt_cont_f32_4(
global float4 * src0,
ulong offset0,
global float4 * dst,
ulong offsetd
) {
src0 = (global float4*)((global char*)src0 + offset0);
dst = (global float4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = sqrt(src0[gid]);
}
kernel void kernel_sqrt_cont_f16(
global half * src0,
ulong offset0,
global half * dst,
ulong offsetd
) {
src0 = (global half*)((global char*)src0 + offset0);
dst = (global half*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = convert_half(sqrt(convert_float(src0[gid])));
}
kernel void kernel_sqrt_cont_f16_4(
global half4 * src0,
ulong offset0,
global half4 * dst,
ulong offsetd
) {
src0 = (global half4*)((global char*)src0 + offset0);
dst = (global half4*)((global char*)dst + offsetd);
uint gid = get_global_id(0);
dst[gid] = convert_half4(sqrt(convert_float4(src0[gid])));
}

View File

@ -0,0 +1,77 @@
kernel void kernel_ssm_conv_f32_f32(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb11,
ulong nb0,
ulong nb1,
ulong nb2
){
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int ir = get_global_id(0);
int i2 = get_global_id(1);
int i3 = get_global_id(2);
int nc = ne10;
global float * s = (global float *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
global float * c = (global float *) (src1 + ir*nb11);
global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
float sumf = 0.0f;
for (int i0 = 0; i0 < nc; ++i0) {
sumf += s[i0] * c[i0];
}
d[0] = sumf;
}
kernel void kernel_ssm_conv_f32_f32_4(
global char * src0,
ulong offset0,
global char * src1,
ulong offset1,
global char * dst,
ulong offsetd,
ulong nb00,
ulong nb01,
ulong nb02,
int ne10,
ulong nb11,
ulong nb0,
ulong nb1,
ulong nb2
) {
src0 = src0 + offset0;
src1 = src1 + offset1;
dst = dst + offsetd;
int ir = get_global_id(0);
int i2 = get_global_id(1);
int i3 = get_global_id(2);
int nc = ne10;
global float4 * s = (global float4 *) (src0 + ir*nb01 + i2*nb00 + i3*nb02);
global float4 * c = (global float4 *) (src1 + ir*nb11);
global float * d = (global float *) (dst + ir*nb0 + i2*nb1 + i3*nb2);
float sumf = 0.0f;
for (int i0 = 0; i0 < nc/4; ++i0) {
sumf += dot(s[i0], c[i0]);
}
d[0] = sumf;
}
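
Both variants compute the same per-row reduction; the _4 kernel merely packs it into float4 dot() calls, which is why the host wrapper only selects it when ne10 is a multiple of 4:

d_{ir,\,i_2,\,i_3} = \sum_{i_0 = 0}^{nc - 1} s_{i_0}\, c_{i_0}

where s is the src0 window addressed by (ir, i2, i3) and c is row ir of the filter tensor src1.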

View File

@ -106,6 +106,7 @@ enum rpc_cmd {
RPC_CMD_GET_ALLOC_SIZE,
RPC_CMD_HELLO,
RPC_CMD_DEVICE_COUNT,
RPC_CMD_GRAPH_RECOMPUTE,
RPC_CMD_COUNT,
};
@ -205,10 +206,6 @@ struct rpc_msg_copy_tensor_rsp {
uint8_t result;
};
struct rpc_msg_graph_compute_rsp {
uint8_t result;
};
struct rpc_msg_get_device_memory_req {
uint32_t device;
};
@ -217,6 +214,11 @@ struct rpc_msg_get_device_memory_rsp {
uint64_t free_mem;
uint64_t total_mem;
};
struct rpc_msg_graph_recompute_req {
uint32_t device;
};
#pragma pack(pop)
// RPC data structures
@ -234,10 +236,35 @@ struct ggml_backend_rpc_buffer_type_context {
size_t max_size;
};
struct graph_cache {
bool is_cached(const ggml_cgraph * cgraph) {
if ((int)last_graph.size() != cgraph->n_nodes) {
return false;
}
for (int i = 0; i < cgraph->n_nodes; i++) {
if (memcmp(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor)) != 0) {
return false;
}
}
return true;
}
void add(const ggml_cgraph * cgraph) {
last_graph.resize(cgraph->n_nodes);
for (int i = 0; i < cgraph->n_nodes; i++) {
memcpy(&last_graph[i], cgraph->nodes[i], sizeof(ggml_tensor));
}
}
std::vector<ggml_tensor> last_graph;
};
struct ggml_backend_rpc_context {
std::string endpoint;
uint32_t device;
std::string name;
graph_cache gc;
};
struct ggml_backend_rpc_buffer_context {
@ -815,13 +842,24 @@ static void serialize_graph(uint32_t device, const ggml_cgraph * cgraph, std::ve
static enum ggml_status ggml_backend_rpc_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) {
ggml_backend_rpc_context * rpc_ctx = (ggml_backend_rpc_context *)backend->context;
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
rpc_msg_graph_compute_rsp response;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size(), &response, sizeof(response));
RPC_STATUS_ASSERT(status);
return (enum ggml_status)response.result;
GGML_ASSERT(cgraph->n_nodes > 0);
bool reuse = rpc_ctx->gc.is_cached(cgraph);
if (reuse) {
rpc_msg_graph_recompute_req request;
request.device = rpc_ctx->device;
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_RECOMPUTE, &request, sizeof(request));
RPC_STATUS_ASSERT(status);
} else {
rpc_ctx->gc.add(cgraph);
std::vector<uint8_t> input;
serialize_graph(rpc_ctx->device, cgraph, input);
auto sock = get_socket(rpc_ctx->endpoint);
bool status = send_rpc_cmd(sock, RPC_CMD_GRAPH_COMPUTE, input.data(), input.size());
RPC_STATUS_ASSERT(status);
}
return GGML_STATUS_SUCCESS;
}
static ggml_backend_i ggml_backend_rpc_interface = {
@ -880,7 +918,8 @@ ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device) {
ggml_backend_rpc_context * ctx = new ggml_backend_rpc_context {
/* .endpoint = */ endpoint,
/* .device = */ device,
/* .name = */ dev_name
/* .name = */ dev_name,
/* .gc = */ {},
};
auto reg = ggml_backend_rpc_add_server(endpoint);
ggml_backend_t backend = new ggml_backend {
@ -920,8 +959,9 @@ void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device,
class rpc_server {
public:
rpc_server(std::vector<ggml_backend_t> backends, const char * cache_dir)
: backends(std::move(backends)), cache_dir(cache_dir) {
rpc_server(std::vector<ggml_backend_t> all_backends, const char * cache_dir)
: backends(std::move(all_backends)), cache_dir(cache_dir) {
stored_graphs.resize(backends.size());
}
~rpc_server();
@ -936,11 +976,17 @@ public:
bool set_tensor_hash(const rpc_msg_set_tensor_hash_req & request, rpc_msg_set_tensor_hash_rsp & response);
bool get_tensor(const rpc_msg_get_tensor_req & request, std::vector<uint8_t> & response);
bool copy_tensor(const rpc_msg_copy_tensor_req & request, rpc_msg_copy_tensor_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response);
bool graph_compute(const std::vector<uint8_t> & input);
bool graph_recompute(const rpc_msg_graph_recompute_req & request);
bool init_tensor(const rpc_msg_init_tensor_req & request);
bool get_alloc_size(const rpc_msg_get_alloc_size_req & request, rpc_msg_get_alloc_size_rsp & response);
bool get_device_memory(const rpc_msg_get_device_memory_req & request, rpc_msg_get_device_memory_rsp & response);
struct stored_graph {
ggml_context_ptr ctx_ptr;
ggml_cgraph * graph;
};
private:
bool get_cached_file(uint64_t hash, std::vector<uint8_t> & data);
ggml_tensor * deserialize_tensor(struct ggml_context * ctx, const rpc_tensor * tensor);
@ -953,6 +999,8 @@ private:
std::vector<ggml_backend_t> backends;
const char * cache_dir;
std::unordered_set<ggml_backend_buffer_t> buffers;
// store the last computed graph for each backend
std::vector<stored_graph> stored_graphs;
};
void rpc_server::hello(rpc_msg_hello_rsp & response) {
@ -1394,7 +1442,7 @@ ggml_tensor * rpc_server::create_node(uint64_t id,
return result;
}
bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph_compute_rsp & response) {
bool rpc_server::graph_compute(const std::vector<uint8_t> & input) {
// serialization format:
// | device (4 bytes) | n_nodes (4 bytes) | nodes (n_nodes * sizeof(uint64_t)) | n_tensors (4 bytes) | tensors (n_tensors * sizeof(rpc_tensor)) |
if (input.size() < 2*sizeof(uint32_t)) {
@ -1455,7 +1503,24 @@ bool rpc_server::graph_compute(const std::vector<uint8_t> & input, rpc_msg_graph
}
}
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
response.result = status;
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
stored_graphs[device].ctx_ptr.swap(ctx_ptr);
stored_graphs[device].graph = graph;
return true;
}
bool rpc_server::graph_recompute(const rpc_msg_graph_recompute_req & request) {
uint32_t device = request.device;
if (device >= backends.size()) {
return false;
}
if (stored_graphs[device].graph == nullptr) {
return false;
}
ggml_cgraph * graph = stored_graphs[device].graph;
LOG_DBG("[%s] device: %u\n", __func__, device);
ggml_status status = ggml_backend_graph_compute(backends[device], graph);
GGML_ASSERT(status == GGML_STATUS_SUCCESS && "Unsuccessful graph computations are not supported with RPC");
return true;
}
@ -1690,11 +1755,17 @@ static void rpc_serve_client(const std::vector<ggml_backend_t> & backends, const
if (!recv_msg(sockfd, input)) {
return;
}
rpc_msg_graph_compute_rsp response;
if (!server.graph_compute(input, response)) {
if (!server.graph_compute(input)) {
return;
}
if (!send_msg(sockfd, &response, sizeof(response))) {
break;
}
case RPC_CMD_GRAPH_RECOMPUTE: {
rpc_msg_graph_recompute_req request;
if (!recv_msg(sockfd, &request, sizeof(request))) {
return;
}
if (!server.graph_recompute(request)) {
return;
}
break;

View File

@ -91,7 +91,10 @@ if (GGML_SYCL_F16)
add_compile_definitions(GGML_SYCL_F16)
endif()
if (GGML_SYCL_TARGET STREQUAL "NVIDIA")
if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# INFO: Allowed Sub_group_sizes are not consistent through all
@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
# default for other targets
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
endif()
if (GGML_SYCL_GRAPH)

View File

@ -617,4 +617,30 @@ static __dpct_inline__ float get_alibi_slope(const float max_bias,
return dpct::pow(base, exph);
}
static const sycl::uint3 init_fastdiv_values(uint32_t d) {
GGML_ASSERT(d != 0);
uint32_t L = 0;
while (L < 32 && (uint32_t{ 1 } << L) < d) {
L++;
}
uint32_t mp = (uint32_t) ((uint64_t{ 1 } << 32) * ((uint64_t{ 1 } << L) - d) / d + 1);
return sycl::uint3(mp, L, d);
}
static __dpct_inline__ uint32_t fastdiv(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t hi = sycl::mul_hi<unsigned>(n, fastdiv_values.x());
return (hi + n) >> fastdiv_values.y();
}
static __dpct_inline__ sycl::uint2 fast_div_modulo(uint32_t n, const sycl::uint3 fastdiv_values) {
const uint32_t div_val = fastdiv(n, fastdiv_values);
const uint32_t mod_val = n - div_val * fastdiv_values.z();
return sycl::uint2(div_val, mod_val);
}
#endif // GGML_SYCL_COMMON_HPP
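
The fastdiv helpers above follow the usual magic-number scheme for division by a fixed 32-bit d: with L chosen as the smallest shift such that 2^L >= d,

m_p = \left\lfloor \frac{2^{32}\,(2^{L} - d)}{d} \right\rfloor + 1, \qquad \left\lfloor \frac{n}{d} \right\rfloor = \bigl(\mathrm{mulhi}(n, m_p) + n\bigr) \gg L

and fast_div_modulo recovers the remainder as n - d * floor(n/d), so each per-element division becomes a mulhi, an add and a shift.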

View File

@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
GGML_TENSOR_BINARY_OP_LOCALS01;
SYCL_CHECK(ggml_sycl_set_device(ctx.device));

View File

@ -1,72 +1,100 @@
#include "pad_reflect_1d.hpp"
void pad_reflect_1d_f32(const float* src,float* dst,
const int64_t ne0, const int64_t ne02, const int p0, const int p1,
const int64_t nb0, const int64_t nb1, const int64_t nb2, const int64_t nb3,
const int64_t nb00, const int64_t nb01, const int64_t nb02, const int64_t nb03,
const sycl::nd_item<3> &item_ct1){
static void pad_reflect_1d_kernel_f32(
const void *__restrict__ src0, void *__restrict__ dst, const int64_t ne0,
const int64_t ne00, const sycl::uint3 ne01, const int64_t ne02,
const int64_t ne03, const int64_t nb00, const int64_t nb01,
const int64_t nb02, const int64_t nb03, const int64_t nb0,
const int64_t nb1, const int64_t nb2, const int64_t nb3, const int p0,
const int p1, sycl::nd_item<3> item_ct1) {
const int i0 = item_ct1.get_group(0) * SYCL_CONCAT_BLOCK_SIZE + item_ct1.get_local_id(0);
const int i1 = item_ct1.get_group(1);
const int g2 = item_ct1.get_group(2);
const int i2 = g2 % ne02;
const int i3 = g2 / ne02;
const int64_t i3 = item_ct1.get_group(0);
const int64_t i2 = item_ct1.get_group(1);
if (i0 >= p0 + ne0 + p1) return;
const sycl::uint2 div_mod_packed =
fast_div_modulo(item_ct1.get_group(2), ne01);
const int64_t tile1 = div_mod_packed.y();
const int64_t tile0 = div_mod_packed.x();
const int64_t i1 = tile1;
const int64_t i0 =
item_ct1.get_local_id(2) + tile0 * item_ct1.get_local_range(2);
int t = i0 - p0;
int period = 2 * ne0 -2;
int m = t % period;
m += (m < 0) * period;
int center = ne0 -1;
int srci0 = center - abs(center - m);
if (i0 >= ne0 || i1 >= ne01.z() || i2 >= ne02 || i3 >= ne03) {
return;
}
int offest_src = i3*nb3 + i2*nb2 + i1*nb1 + srci0*nb0;
int offest_dst = i3*nb03 + i2*nb02 + i1*nb01 + i0*nb00;
dst[offest_dst] = src[offest_src];
const char *src0_ptr =
(const char *)src0 + i3 * nb03 + i2 * nb02 + i1 * nb01;
char *dst_ptr = (char *)dst + i3 * nb3 + i2 * nb2 + i1 * nb1;
const int64_t rel_i0 = i0 - p0; // relative i0 in src0
int64_t src_idx;
if (rel_i0 < 0) {
// Left padding - reflect
src_idx = -rel_i0;
} else if (rel_i0 < ne00) {
// Middle - copy
src_idx = rel_i0;
} else {
// Right padding - reflect
src_idx = 2 * ne00 - 2 - rel_i0;
}
const float value = *(const float *)(src0_ptr + src_idx * nb00);
*(float *)(dst_ptr + i0 * nb0) = value;
GGML_UNUSED(p1);
}
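
A small worked example of the index mapping above (sizes chosen for illustration, not taken from the diff): with ne00 = 5 and p0 = p1 = 2 the padded row has ne0 = 9 elements and reads the source indices

// i0      :  0  1  2  3  4  5  6  7  8
// rel_i0  : -2 -1  0  1  2  3  4  5  6
// src_idx :  2  1  0  1  2  3  4  3  2

i.e. the left and right pads mirror the row around its first and last elements without repeating them.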
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst){
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context &ctx,
ggml_tensor *dst) {
const ggml_tensor * src0 = dst->src[0];
queue_ptr stream = ctx.stream();
const ggml_tensor *src0 = dst->src[0];
dpct::queue_ptr stream = ctx.stream();
GGML_ASSERT(src0->type == GGML_TYPE_F32);
GGML_ASSERT( dst->type == GGML_TYPE_F32);
GGML_ASSERT(dst->type == GGML_TYPE_F32);
const int32_t * opts = (const int32_t *) dst->op_params;
const int32_t *opts = (const int32_t *)dst->op_params;
const int p0 = opts[0];
const int p1 = opts[1];
const int64_t ne0 = src0->ne[0];
const int64_t ne00 = src0->ne[0];
const int64_t ne01 = src0->ne[1];
const sycl::uint3 ne01_packed = init_fastdiv_values(ne01);
const int64_t ne02 = src0->ne[2];
const int64_t ne03 = src0->ne[3];
const int64_t ne00 = dst->ne[0];
const int64_t ne01 = dst->ne[1];
const int64_t ne02 = dst->ne[2];
const int64_t ne03 = dst->ne[3];
const int64_t ne0 = dst->ne[0];
const int64_t nb00 = dst->nb[0];
const int64_t nb01 = dst->nb[1];
const int64_t nb02 = dst->nb[2];
const int64_t nb03 = dst->nb[3];
const int64_t nb0 = src0->nb[0];
const int64_t nb1 = src0->nb[1];
const int64_t nb2 = src0->nb[2];
const int64_t nb3 = src0->nb[3];
GGML_ASSERT(ne0 == ne00 + p0 + p1);
int num_blocks = (ne00 + SYCL_CONCAT_BLOCK_SIZE - 1) / SYCL_CONCAT_BLOCK_SIZE;
sycl::range<3> global(num_blocks * SYCL_CONCAT_BLOCK_SIZE, ne01, ne02*ne03);
sycl::range<3> local(SYCL_CONCAT_BLOCK_SIZE, 1, 1);
constexpr int64_t bx = SYCL_PAD_REFLECT_1D_BLOCK_SIZE;
const int64_t tiles0 = (ne0 + bx - 1) / bx;
const dpct::dim3 grid_dims((unsigned)(ne01 * tiles0), (unsigned)ne02,
(unsigned)ne03);
const dpct::dim3 block_dims((unsigned)bx, 1, 1);
stream->parallel_for(
sycl::nd_range<3>(global,
local),
[=](sycl::nd_item<3> item_ct1) { pad_reflect_1d_f32(
(const float *) src0->data, (float *) dst->data,
ne0, ne02, p0, p1,
nb0, nb1, nb2, nb3,
nb00, nb01, nb02, nb03
, item_ct1);
});
stream->submit([&](sycl::handler &cgh) {
auto src0_data_ct0 = src0->data;
auto dst_data_ct1 = dst->data;
auto src0_nb_ct7 = src0->nb[0];
auto src0_nb_ct8 = src0->nb[1];
auto src0_nb_ct9 = src0->nb[2];
auto src0_nb_ct10 = src0->nb[3];
auto dst_nb_ct11 = dst->nb[0];
auto dst_nb_ct12 = dst->nb[1];
auto dst_nb_ct13 = dst->nb[2];
auto dst_nb_ct14 = dst->nb[3];
cgh.parallel_for(sycl::nd_range<3>(grid_dims * block_dims, block_dims),
[=](sycl::nd_item<3> item_ct1) {
pad_reflect_1d_kernel_f32(
src0_data_ct0, dst_data_ct1, ne0, ne00,
ne01_packed, ne02, ne03, src0_nb_ct7,
src0_nb_ct8, src0_nb_ct9, src0_nb_ct10,
dst_nb_ct11, dst_nb_ct12, dst_nb_ct13,
dst_nb_ct14, p0, p1, item_ct1);
});
});
}

View File

@ -3,6 +3,8 @@
#include "common.hpp"
#define SYCL_PAD_REFLECT_1D_BLOCK_SIZE 256
void ggml_sycl_op_pad_reflect_1d(ggml_backend_sycl_context& ctx, ggml_tensor* dst);
#endif // GGML_SYCL_PAD_REFLECT_1D_HPP

View File

@ -399,6 +399,18 @@ struct vk_conv2d_pipeline_state {
}
};
struct vk_solve_tri_pipeline_state {
vk_solve_tri_pipeline_state(uint32_t N, uint32_t K)
: N(N), K(K) {}
uint32_t N, K;
bool operator<(const vk_solve_tri_pipeline_state &b) const {
return std::tie(N, K) <
std::tie(b.N, b.K);
}
};
enum shader_reduction_mode {
SHADER_REDUCTION_MODE_SHMEM,
SHADER_REDUCTION_MODE_HYBRID,
@ -601,9 +613,10 @@ struct vk_device_struct {
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@ -637,6 +650,7 @@ struct vk_device_struct {
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_tri[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
@ -711,6 +725,7 @@ struct vk_device_struct {
vk_pipeline pipeline_cumsum_f32;
vk_pipeline pipeline_argmax_f32;
vk_pipeline pipeline_count_equal_i32;
std::map<vk_solve_tri_pipeline_state, vk_pipeline> pipeline_solve_tri_f32;
vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16;
vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16;
vk_pipeline pipeline_timestep_embedding_f32;
@ -1597,7 +1612,7 @@ class vk_perf_logger {
}
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->ne[1];
const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
std::string name = ggml_op_name(node->op);
@ -3511,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
// the number of rows computed per shader depends on GPU model and quant
uint32_t rm_stdq = 1;
uint32_t rm_kq = 2;
uint32_t rm_stdq_int = 1;
uint32_t rm_kq_int = 1;
if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->architecture == AMD_GCN) {
rm_stdq = 2;
rm_kq = 4;
rm_stdq_int = 4;
}
} else if (device->vendor_id == VK_VENDOR_ID_INTEL)
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
rm_stdq = 2;
rm_stdq_int = 2;
}
uint32_t rm_iq = 2 * rm_kq;
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
@ -3598,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
}
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
#if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
GGML_UNUSED(rm_stdq_int);
GGML_UNUSED(rm_kq_int);
#endif
// dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@ -3863,6 +3917,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
}
ggml_vk_create_pipeline(device, device->pipeline_tri[0], "tri_f32", tri_f32_len, tri_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_tri[1], "tri_f16", tri_f16_len, tri_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1);
@ -4002,6 +4059,14 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1);
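// Compile one specialized solve_tri pipeline per (N, K) combination recorded in the cache; N and K are baked in as spec constants.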
for (auto &s : device->pipeline_solve_tri_f32) {
const vk_solve_tri_pipeline_state &state = s.first;
ggml_vk_create_pipeline(
device, s.second, "solve_tri_f32",
solve_tri_f32_len, solve_tri_f32_data, "main", 3,
sizeof(vk_op_binary_push_constants), {1, 1, 1}, { 0, state.N, state.K }, 1, true);
}
#define IM2COL(bda) \
ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \
ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \
@ -5289,7 +5354,8 @@ static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) {
ctx->prealloc_size_x = 0;
ctx->prealloc_size_y = 0;
ctx->prealloc_size_split_k = 0;
ctx->prealloc_size_add_rms_partials = 0;
// Fixed size of 1KB, for deterministic behavior
ctx->prealloc_size_add_rms_partials = 1024;
ctx->fence = ctx->device->device.createFence({});
ctx->almost_ready_fence = ctx->device->device.createFence({});
@ -5427,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
@ -5566,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
}
}
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) {
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
GGML_ASSERT(b_type == GGML_TYPE_F32);
GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
if (b_type == GGML_TYPE_Q8_1) {
switch (a_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
}
}
switch (a_type) {
case GGML_TYPE_F32:
@ -5599,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return nullptr;
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type];
// heuristic to choose workgroup size
uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
// Prefer larger workgroups when M is small, to spread the work out more
// and keep more SMs busy.
// q6_k seems to prefer small workgroup size even for "medium" values of M.
if (a_type == GGML_TYPE_Q6_K) {
if (m < 4096 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
} else {
if (m <= 8192 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
}
}
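// Pick the pipeline variant for the chosen workgroup size; for the q8_1 path, Intel always uses the subgroup-sized variant.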
if (b_type == GGML_TYPE_Q8_1) {
if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
}
static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
@ -6791,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return false;
}
// General performance issue with q3_k and q6_k due to 2-byte alignment
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
return false;
}
// MMVQ is generally good for batches
if (n > 1) {
return true;
}
// Quantization overhead is not worth it for small k
switch (device->vendor_id) {
case VK_VENDOR_ID_NVIDIA:
if (k <= 4096) {
return false;
}
switch (src0_type) {
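// mxfp4 and q8_0 only benefit from the integer-dot path on pre-Turing hardware; newer NVIDIA architectures stay on the regular dequant shaders.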
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
default:
return true;
}
case VK_VENDOR_ID_AMD:
if (k < 2048) {
return false;
}
switch (src0_type) {
case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::AMD_GCN;
@ -6812,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return true;
}
case VK_VENDOR_ID_INTEL:
if (k < 2048) {
return false;
}
switch (src0_type) {
// From tests on A770 Linux, may need more tuning
case GGML_TYPE_Q4_0:
@ -6825,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
}
GGML_UNUSED(m);
GGML_UNUSED(k);
}
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@ -7548,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (x_non_contig || qx_needs_dequant) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
@ -7574,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t ne10 = src1->ne[0];
const uint64_t ne11 = src1->ne[1];
// const uint64_t ne12 = src1->ne[2];
const uint64_t ne12 = src1->ne[2];
// const uint64_t ne13 = src1->ne[3];
const uint64_t nei0 = ids->ne[0];
@ -7591,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
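// Quantize src1 to q8_1 on the fly when the device supports integer dot products, src1 is contiguous f32, and the mmvq heuristic says it pays off for this shape.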
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr;
@ -7615,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
}
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
// Check for the mmvq (q8_1) pipeline first
vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
vk_pipeline to_q8_1 = nullptr;
if (dmmv == nullptr) {
// Fall back to f16 dequant mul mat
dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
quantize_y = false;
}
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr);
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
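// For the q8_1 path, pad the element count to a multiple of 128 and size the buffer in q8_1 block bytes; otherwise size for f32 or f16 as before.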
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
{
if (
(qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
@ -7630,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) {
if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
ctx->prealloc_size_y = y_sz;
ggml_vk_preallocate_buffers(ctx, subctx);
}
@ -7642,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
}
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
}
@ -7657,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else {
d_X = d_Qx;
}
if (qy_needs_dequant) {
if (qy_needs_dequant || quantize_y) {
d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
} else {
d_Y = d_Qy;
@ -7685,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_y_last_tensor_used = src1;
}
}
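// Quantize src1 into the preallocated y buffer, unless it already holds this tensor's q8_1 data from a previous node.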
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
uint32_t stride_batch_y = ne10*ne11;
@ -7746,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (x_non_contig) {
ctx->prealloc_x_need_sync = true;
}
if (y_non_contig) {
if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true;
}
}
@ -8268,6 +8430,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_TRI:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_tri[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
@ -8495,6 +8663,26 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cumsum_f32;
}
return nullptr;
case GGML_OP_SOLVE_TRI:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
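// solve_tri pipelines are specialized per (N, K): reuse the cached entry for this shape or register a new one to be compiled.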
vk_solve_tri_pipeline_state solve_tri_pipeline_state(src0->ne[0], src1->ne[0]);
vk_pipeline pipeline = nullptr;
{
std::lock_guard<std::recursive_mutex> guard(ctx->device->mutex);
auto it = ctx->device->pipeline_solve_tri_f32.find(solve_tri_pipeline_state);
if (it != ctx->device->pipeline_solve_tri_f32.end()) {
pipeline = it->second;
} else {
ctx->device->pipeline_solve_tri_f32[solve_tri_pipeline_state] = pipeline = std::make_shared<vk_pipeline_struct>();
}
}
return pipeline;
}
return nullptr;
case GGML_OP_ARGMAX:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) {
return ctx->device->pipeline_argmax_f32;
@ -8686,41 +8874,6 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
GGML_UNUSED(src2);
}
static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
switch (op) {
case GGML_OP_CPY:
case GGML_OP_GET_ROWS:
case GGML_OP_ADD:
case GGML_OP_SUB:
case GGML_OP_MUL:
case GGML_OP_DIV:
case GGML_OP_ADD_ID:
case GGML_OP_CONCAT:
case GGML_OP_UPSCALE:
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
case GGML_OP_REPEAT_BACK:
case GGML_OP_ROPE:
case GGML_OP_RMS_NORM:
case GGML_OP_CONV_2D_DW:
case GGML_OP_IM2COL:
case GGML_OP_IM2COL_3D:
case GGML_OP_SET_ROWS:
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
return true;
default:
return false;
}
}
template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * src3, ggml_tensor * dst) {
const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type);
const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type);
@ -8805,7 +8958,6 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
std::cerr << "), " << ggml_op_name(op) << ")");
GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT
GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT
GGML_ASSERT(dst->buffer != nullptr);
const uint64_t ne00 = src0->ne[0];
const uint64_t ne01 = src0->ne[1];
@ -8836,22 +8988,17 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1);
const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op);
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, op_supports_incontiguous);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, op_supports_incontiguous) : vk_subbuffer{};
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, op_supports_incontiguous);
vk_subbuffer src0_buf = ggml_vk_tensor_subbuffer(ctx, src0, true);
vk_subbuffer src1_buf = use_src1 ? ggml_vk_tensor_subbuffer(ctx, src1, true) : vk_subbuffer{};
vk_subbuffer src2_buf = use_src2 ? ggml_vk_tensor_subbuffer(ctx, src2, true) : vk_subbuffer{};
vk_subbuffer src3_buf = use_src3 ? ggml_vk_tensor_subbuffer(ctx, src3, true) : vk_subbuffer{};
vk_subbuffer dst_buf = ggml_vk_tensor_subbuffer(ctx, dst, true);
// Compute the misalignment offsets for the descriptors and store them in the push constants.
init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, src3, dst);
std::array<uint32_t, 3> elements;
// Single call if dimension 2 is contiguous
GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1))));
switch (op) {
case GGML_OP_NORM:
case GGML_OP_RMS_NORM_BACK:
@ -8872,6 +9019,18 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
elements = { nr, 1, 1 };
}
} break;
case GGML_OP_SOLVE_TRI:
{
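// Dispatch one workgroup per batch element (ne02*ne03), folding the count into up to three grid dimensions to respect per-dimension dispatch limits.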
uint32_t nr = (uint32_t)(ne02 * ne03);
if (nr > 262144) {
elements = { 512, 512, CEIL_DIV(nr, 262144) };
} else if (nr > 512) {
elements = { 512, CEIL_DIV(nr, 512), 1 };
} else {
elements = { nr, 1, 1 };
}
}
break;
case GGML_OP_RMS_NORM:
if (ctx->do_add_rms_partials) {
// Run one element per thread, 128 threads per workgroup
@ -8978,6 +9137,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_TRI:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@ -9658,6 +9818,13 @@ static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
}
static void ggml_vk_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_TRI, std::move(p));
}
static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
@ -10208,7 +10375,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup
uint32_t max_pipeline = num_topk_pipelines - 3;
uint32_t max_pipeline = num_topk_pipelines - 1;
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);
@ -10300,6 +10469,21 @@ static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subct
ggml_vk_op_f32<vk_op_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f });
}
static void ggml_vk_solve_tri(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const uint32_t src0_type_size = ggml_type_size(src0->type);
const uint32_t src1_type_size = ggml_type_size(src1->type);
const uint32_t dst_type_size = ggml_type_size(dst->type);
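// Pass element counts and strides (byte strides divided by the type size, i.e. in elements) for src0, src1 and dst via the binary push constants.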
ggml_vk_op_f32<vk_op_binary_push_constants>(ctx, subctx, src0, src1, nullptr, nullptr, dst, GGML_OP_SOLVE_TRI, {
(uint32_t)ggml_nelements(src0),
(uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size,
(uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size,
(uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size,
0,
0.0f, 0.0f, 0,
});
}
static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
const int32_t s0 = dst->op_params[0];
const int32_t s1 = dst->op_params[1];
@ -11766,6 +11950,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_LOG:
ggml_vk_log(ctx, compute_ctx, src0, node);
break;
case GGML_OP_TRI:
ggml_vk_tri(ctx, compute_ctx, src0, node);
break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@ -11911,6 +12099,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COUNT_EQUAL:
ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_SOLVE_TRI:
ggml_vk_solve_tri(ctx, compute_ctx, src0, src1, node);
break;
case GGML_OP_IM2COL:
ggml_vk_im2col(ctx, compute_ctx, src0, src1, node);
@ -13095,7 +13287,6 @@ static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cg
ctx->fused_ops_write_mask = 0;
}
ctx->prealloc_size_add_rms_partials = std::max(ctx->prealloc_size_add_rms_partials, ctx->prealloc_size_add_rms_partials_offset);
ctx->last_total_mul_mat_bytes = total_mul_mat_bytes;
if (vk_perf_logger_enabled) {
@ -13876,17 +14067,21 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
op->type == GGML_TYPE_F32;
case GGML_OP_SILU_BACK:
case GGML_OP_RMS_NORM_BACK:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SQR:
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_CLAMP:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LEAKY_RELU:
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_TRI:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) &&
op->type == op->src[0]->type;
case GGML_OP_ARGSORT:
{
if (!ggml_is_contiguous(op) || !ggml_is_contiguous(op->src[0])) {
@ -13919,17 +14114,29 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
case GGML_OP_UPSCALE:
case GGML_OP_ACC:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONCAT:
return ggml_type_size(op->src[0]->type) == ggml_type_size(GGML_TYPE_F32);
case GGML_OP_ADD1:
return (op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32)
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F32)
|| (op->src[0]->type == GGML_TYPE_F16 && op->src[1]->type == GGML_TYPE_F16);
case GGML_OP_ARANGE:
case GGML_OP_FILL:
return op->type == GGML_TYPE_F32;
case GGML_OP_SCALE:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_PAD:
case GGML_OP_ROLL:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_DIAG_MASK_INF:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_SOFT_MAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
&& (!op->src[1] || (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16));
case GGML_OP_SOFT_MAX_BACK:
return true;
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_SUM:
case GGML_OP_SUM_ROWS:
case GGML_OP_MEAN:
@ -13943,16 +14150,47 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
return false;
}
case GGML_OP_SOLVE_TRI:
{
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context;
const vk_device& device = ggml_vk_get_device(ctx->device);
if (op->type != GGML_TYPE_F32 || op->src[0]->type != GGML_TYPE_F32) {
return false;
}
const uint32_t N = op->src[0]->ne[0];
const uint32_t K = op->src[1]->ne[0];
// K dimension limited to workgroup size
if (K > 128) {
return false;
}
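// The shader keeps the full N x N triangular matrix and the N x K right-hand side in shared memory, so reject shapes that exceed the device limit.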
if (N * N * sizeof(float) + N * K * sizeof(float) > device->properties.limits.maxComputeSharedMemorySize) {
return false;
}
return true;
}
case GGML_OP_ARGMAX:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_COUNT_EQUAL:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_I32
&& ggml_is_contiguous(op->src[1]) && op->src[1]->type == GGML_TYPE_I32;
case GGML_OP_IM2COL:
return ggml_is_contiguous(op->src[1])
&& op->src[1]->type == GGML_TYPE_F32
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_IM2COL_3D:
return op->src[1]->type == GGML_TYPE_F32
&& (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16);
case GGML_OP_TIMESTEP_EMBEDDING:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D_DW:
return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16)
&& op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_POOL_2D:
return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_RWKV_WKV6:
case GGML_OP_RWKV_WKV7:
return true;
return true; // all inputs are contiguous, see ggml.c
case GGML_OP_SSM_SCAN:
{
for (int i = 0; i < 6; i++) {
@ -13993,7 +14231,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
return true;
}
case GGML_OP_SSM_CONV:
return true;
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_CONV_TRANSPOSE_1D:
return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32;
case GGML_OP_CONV_2D:
@ -14434,6 +14672,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_TRI) {
tensor_clone = ggml_tri(ggml_ctx, src_clone[0], ggml_get_op_params_i32(tensor, 0));
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
@ -14603,6 +14843,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COUNT_EQUAL) {
tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]);
} else if (tensor->op == GGML_OP_SOLVE_TRI) {
tensor_clone = ggml_solve_tri(ggml_ctx, src_clone[0], src_clone[1], true, true, false);
} else if (tensor->op == GGML_OP_IM2COL) {
const int32_t s0 = tensor->op_params[0];
const int32_t s1 = tensor->op_params[1];

View File

@ -4,13 +4,6 @@
#include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);

View File

@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
#if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif

View File

@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() {

View File

@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

View File

@ -13,8 +13,6 @@
#include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter
{
uint ncols;

View File

@ -5,13 +5,15 @@
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};

View File

@ -10,60 +10,56 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8
#include "mul_mmq_funcs.glsl"
#elif defined(DATA_A_QUANT_K)
#define K_PER_ITER 16
#else
#error unimplemented
#endif
uint a_offset, b_offset, d_offset;
int32_t cache_b_qs[2];
int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds;
#include "mul_mat_vecq_funcs.glsl"
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4;
const uint b_qs_idx = tid % (32 / K_PER_ITER);
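// q8_1 blocks are stored four to a packed struct: outer selects the struct, inner the 32-quant block within it.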
const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
#if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset;
const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols;
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
}
}
}
@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);

View File

@ -0,0 +1,379 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.glsl"
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
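// Q4_0 nibbles carry a +8 bias; the (8 / sum_divisor) * dsb.y term subtracts that bias using the q8_1 block sum, split across calls via sum_divisor.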
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
// 2-byte loads for Q5_0 blocks (22 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
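// Multiplying the four qh bits by 0x02040810 and masking with 0x10101010 scatters them into bit 4 of each output byte, giving the (0,1,2,3) -> (4,12,20,28) mapping noted below.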
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
}
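// kvalues_mxfp4 stores the e2m1 values doubled as integers, hence the 0.5 factor in the result.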
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
}
#endif
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(ib_a, iqs);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(ib_a, iqs * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(ib_a, iqs * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
// 2 packed 32-bit quant words per call, 8 per block => divide the block-sum correction by 8/2 = 4
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
}
#endif
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
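// Each Q2_K scale byte packs a 4-bit scale (low nibble) and a 4-bit min (high nibble); dm holds the super-block d and dmin.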
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t sum_d = 0;
int32_t sum_m = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const uint8_t scale = get_scale(ib_a, iqs * 4);
const vec2 dm = vec2(get_dm(ib_a));
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate the 8-bit value across all 32 bits.
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
}
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// bitwise OR to add 4 if hmask is set, subtract later
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
}
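// Q3_K packs 6-bit scales into 12 bytes: the low 4 bits sit in the first 8 bytes, the high 2 bits in the last 4; recentre by subtracting 32.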
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 4;
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
return float(data_a[ib_k].d) * float(scale - 32);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
return i32vec4(vals0, vals1, vals2, vals3);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs;
const uint qh_shift = iqs_k / 8;
return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
#endif
}
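// Q4_K/Q5_K pack 6-bit scales and mins for 8 sub-blocks into 12 bytes: the first four pairs are stored directly, the last four recover their high bits from the earlier bytes.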
vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif

View File

@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32
#define MMQ_SHMEM
#include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID

View File

@ -9,31 +9,6 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q4_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
#else // DATA_A_Q4_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
#if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
#ifdef DATA_A_Q5_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
#else // DATA_A_Q5_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1]));
@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, qs_b);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1);
return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
}
#endif // MMQ_SHMEM
#endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
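// e.g. a 256-wide k-quant super-block spans 8 such 32-wide blocks: ib_k = ib / 8 selects the super-block and iqs_k = (ib % 8) * 8 + iqs the 32-bit word within it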
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
}
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1);
return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ;
@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif
if (iqs == 0) {
// Scale index
const uint is = iqs_k / 8;
@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
}
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1);
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
}
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
}
#endif

View File

@ -0,0 +1,72 @@
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
layout (constant_id = 1) const uint N = 64;
layout (constant_id = 2) const uint K = 32;
layout(local_size_x = 128, local_size_y = 1, local_size_z = 1) in;
uint a_base, b_base, x_base;
FLOAT_TYPE get_a(uint r, uint c) {
return FLOAT_TYPE(data_a[a_base + r * p.nb01 + c * p.nb00]);
}
FLOAT_TYPE get_b(uint r, uint c) {
return FLOAT_TYPE(data_b[b_base + r * p.nb11 + c * p.nb10]);
}
void store_x(uint r, uint c, FLOAT_TYPE v) {
data_d[x_base + r * p.nb21 + c * p.nb20] = D_TYPE(v);
}
shared FLOAT_TYPE shA[N * N];
shared FLOAT_TYPE shB[N * K];
void main() {
const uint batch = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
if (batch >= p.ne02 * p.ne03) {
return;
}
const uint i3 = batch / p.ne22;
const uint i2 = batch % p.ne22;
a_base = get_aoffset() + i2 * p.nb02 + i3 * p.nb03;
b_base = get_boffset() + i2 * p.nb12 + i3 * p.nb13;
x_base = get_doffset() + i2 * p.nb22 + i3 * p.nb23;
// Load the A matrix into shA
[[unroll]] for (uint i = 0; i < N * N; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * N) % gl_WorkGroupSize.x == 0) || idx < N * N) {
shA[idx] = get_a(idx / N, idx % N);
}
}
// Load the B matrix into shB
[[unroll]] for (uint i = 0; i < N * K; i += gl_WorkGroupSize.x) {
uint idx = i + tid;
if (((N * K) % gl_WorkGroupSize.x == 0) || idx < N * K) {
shB[idx] = get_b(idx / K, idx % K);
}
}
barrier();
FLOAT_TYPE X[N];
// Each thread solves one column
if (tid < K) {
[[unroll]] for (int r = 0; r < N; ++r) {
FLOAT_TYPE b = shB[r * K + tid];
// Forward substitution: x[r][tid] = (b[r][tid] - sum_{c<r} a[r][c] * x[c][tid]) / a[r][r]
[[unroll]] for (int c = 0; c < r; ++c) {
b -= shA[r * N + c] * X[c];
}
FLOAT_TYPE x = b / shA[r * N + r];
X[r] = x;
store_x(r, tid, x);
}
}
}
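
For reference (not part of the diff), this kernel is a per-batch forward substitution: each of the K threads owns one column of B and solves the lower-triangular system row by row. A minimal NumPy sketch of the same computation, assuming A is N x N lower-triangular and B is N x K (the shader defaults are N = 64, K = 32):

import numpy as np

def solve_tri_lower(A: np.ndarray, B: np.ndarray) -> np.ndarray:
    # A: (N, N) lower-triangular, B: (N, K); returns X such that A @ X == B
    N, K = B.shape
    X = np.zeros((N, K), dtype=B.dtype)
    for r in range(N):
        # x[r, :] = (b[r, :] - sum_{c<r} a[r, c] * x[c, :]) / a[r, r]
        X[r] = (B[r] - A[r, :r] @ X[:r]) / A[r, r]
    return X

A = np.tril(np.random.rand(64, 64)) + np.eye(64)   # keep the diagonal well away from zero
B = np.random.rand(64, 32)
assert np.allclose(A @ solve_tri_lower(A, B), B)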

View File

@ -0,0 +1,43 @@
#version 450
#include "rte.glsl"
#include "types.glsl"
#include "generic_unary_head.glsl"
#define GGML_TRI_TYPE_UPPER_DIAG 0
#define GGML_TRI_TYPE_UPPER 1
#define GGML_TRI_TYPE_LOWER_DIAG 2
#define GGML_TRI_TYPE_LOWER 3
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
void main() {
const uint idx = get_idx();
if (idx >= p.ne) {
return;
}
const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L);
const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00;
const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L);
const uint i02_offset = i02*p.ne01*p.ne00;
const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L);
const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00;
int param = floatBitsToInt(p.param1);
bool pass = false;
switch (param) {
case GGML_TRI_TYPE_UPPER_DIAG: pass = i00 >= i01; break;
case GGML_TRI_TYPE_UPPER: pass = i00 > i01; break;
case GGML_TRI_TYPE_LOWER_DIAG: pass = i00 <= i01; break;
case GGML_TRI_TYPE_LOWER: pass = i00 < i01; break;
}
if (pass) {
const float val = float(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val);
} else {
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(0);
}
}
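
For reference (not part of the diff), the four tri types correspond to the usual upper/lower triangular masks, with i01 as the row index and i00 as the column index. A small NumPy sketch using the same param encoding as the constants above:

import numpy as np

GGML_TRI_TYPE_UPPER_DIAG = 0   # keep i00 >= i01 (upper triangle incl. diagonal)
GGML_TRI_TYPE_UPPER      = 1   # keep i00 >  i01 (strictly upper)
GGML_TRI_TYPE_LOWER_DIAG = 2   # keep i00 <= i01 (lower triangle incl. diagonal)
GGML_TRI_TYPE_LOWER      = 3   # keep i00 <  i01 (strictly lower)

def tri(x: np.ndarray, tri_type: int) -> np.ndarray:
    i01, i00 = np.indices(x.shape)   # row and column indices
    mask = {
        GGML_TRI_TYPE_UPPER_DIAG: i00 >= i01,
        GGML_TRI_TYPE_UPPER:      i00 >  i01,
        GGML_TRI_TYPE_LOWER_DIAG: i00 <= i01,
        GGML_TRI_TYPE_LOWER:      i00 <  i01,
    }[tri_type]
    return np.where(mask, x, 0)

x = np.ones((4, 4), dtype=np.float32)
assert np.array_equal(tri(x, GGML_TRI_TYPE_UPPER_DIAG), np.triu(x))
assert np.array_equal(tri(x, GGML_TRI_TYPE_LOWER), np.tril(x, k=-1))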

View File

@ -679,14 +679,20 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) {
if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
}
#endif
@ -846,6 +852,9 @@ void process_shaders() {
string_to_spv("abs_f16", "abs.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("abs_f32", "abs.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("tri_f16", "tri.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("tri_f32", "tri.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
string_to_spv("softplus_f16", "softplus.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}});
string_to_spv("softplus_f32", "softplus.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
@ -944,6 +953,8 @@ void process_shaders() {
string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}}));
string_to_spv("solve_tri_f32", "solve_tri.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}));
for (auto transpose : {false, true}) {
for (auto unroll : {false, true}) {
for (auto a_f16 : {false, true}) {
@ -1095,7 +1106,7 @@ void write_output_files() {
for (const std::string& btype : btypes) {
for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname)) {
if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
continue;
}
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@ -1104,6 +1115,16 @@ void write_output_files() {
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
if (btype == "f16") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
if (basename(input_filepath) == "mul_mat_vec.comp") {
src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
}
}

View File

@ -366,6 +366,7 @@ class MODEL_ARCH(IntEnum):
QWEN2VL = auto()
QWEN3 = auto()
QWEN3MOE = auto()
QWEN3NEXT = auto()
QWEN3VL = auto()
QWEN3VLMOE = auto()
PHI2 = auto()
@ -531,6 +532,7 @@ class MODEL_TENSOR(IntEnum):
SSM_D = auto()
SSM_NORM = auto()
SSM_OUT = auto()
SSM_BETA_ALPHA = auto() # qwen3next
TIME_MIX_W0 = auto()
TIME_MIX_W1 = auto()
TIME_MIX_W2 = auto()
@ -736,6 +738,7 @@ MODEL_ARCH_NAMES: dict[MODEL_ARCH, str] = {
MODEL_ARCH.QWEN2VL: "qwen2vl",
MODEL_ARCH.QWEN3: "qwen3",
MODEL_ARCH.QWEN3MOE: "qwen3moe",
MODEL_ARCH.QWEN3NEXT: "qwen3next",
MODEL_ARCH.QWEN3VL: "qwen3vl",
MODEL_ARCH.QWEN3VLMOE: "qwen3vlmoe",
MODEL_ARCH.PHI2: "phi2",
@ -900,6 +903,7 @@ TENSOR_NAMES: dict[MODEL_TENSOR, str] = {
MODEL_TENSOR.SSM_D: "blk.{bid}.ssm_d",
MODEL_TENSOR.SSM_NORM: "blk.{bid}.ssm_norm",
MODEL_TENSOR.SSM_OUT: "blk.{bid}.ssm_out",
MODEL_TENSOR.SSM_BETA_ALPHA: "blk.{bid}.ssm_ba",
MODEL_TENSOR.TIME_MIX_W0: "blk.{bid}.time_mix_w0",
MODEL_TENSOR.TIME_MIX_W1: "blk.{bid}.time_mix_w1",
MODEL_TENSOR.TIME_MIX_W2: "blk.{bid}.time_mix_w2",
@ -1569,6 +1573,35 @@ MODEL_TENSORS: dict[MODEL_ARCH, list[MODEL_TENSOR]] = {
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
],
MODEL_ARCH.QWEN3NEXT: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,
MODEL_TENSOR.OUTPUT,
MODEL_TENSOR.ATTN_NORM,
MODEL_TENSOR.ATTN_Q,
MODEL_TENSOR.ATTN_Q_NORM,
MODEL_TENSOR.ATTN_K,
MODEL_TENSOR.ATTN_K_NORM,
MODEL_TENSOR.ATTN_V,
MODEL_TENSOR.ATTN_OUT,
MODEL_TENSOR.ATTN_POST_NORM,
MODEL_TENSOR.ATTN_GATE,
MODEL_TENSOR.FFN_GATE_INP,
MODEL_TENSOR.FFN_GATE_INP_SHEXP,
MODEL_TENSOR.FFN_UP_SHEXP,
MODEL_TENSOR.FFN_DOWN_SHEXP,
MODEL_TENSOR.FFN_GATE_SHEXP,
MODEL_TENSOR.FFN_DOWN_EXP,
MODEL_TENSOR.FFN_UP_EXP,
MODEL_TENSOR.FFN_GATE_EXP,
MODEL_TENSOR.SSM_A,
MODEL_TENSOR.SSM_CONV1D,
MODEL_TENSOR.SSM_DT,
MODEL_TENSOR.SSM_NORM,
MODEL_TENSOR.SSM_IN,
MODEL_TENSOR.SSM_BETA_ALPHA,
MODEL_TENSOR.SSM_OUT
],
MODEL_ARCH.QWEN3VL: [
MODEL_TENSOR.TOKEN_EMBD,
MODEL_TENSOR.OUTPUT_NORM,

View File

@ -371,10 +371,13 @@ class GGUFWriter:
def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None,
raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
) -> None:
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None:
@ -397,13 +400,16 @@ class GGUFWriter:
if pad != 0:
fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None:
def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'):
# if tensor endianness is not passed, assume it's native to system
if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False)

View File

@ -19,6 +19,11 @@ import gguf
logger = logging.getLogger("gguf-convert-endian")
def byteswap_noop(tensor, block_offs):
# this function is used when byteswapping is not needed
pass
def byteswap_q4_0(tensor, block_offs):
# Each block_q4_0 consists of an f16 delta (scaling factor) followed by 16 int8 quantizations.
@ -55,22 +60,11 @@ def byteswap_q6_k(tensor, block_offs):
byteswap_tensors = {
gguf.GGMLQuantizationType.Q4_0: {
"block_size": 18, # 18 bytes = <f16 delta scaling factor> + 16 * <int8 quant>
"byteswap_func": byteswap_q4_0,
},
gguf.GGMLQuantizationType.Q8_0: {
"block_size": 34, # 34 bytes = <f16 delta scaling factor> + 32 * <int8 quant>
"byteswap_func": byteswap_q8_0,
},
gguf.GGMLQuantizationType.Q4_K: {
"block_size": 144, # 144 bytes = 2 * <f16 delta scaling factor> + 140 * <int8 quant>
"byteswap_func": byteswap_q4_k,
},
gguf.GGMLQuantizationType.Q6_K: {
"block_size": 210, # 210 bytes = <f16 delta scaling factor> + 208 * <int8 quant>
"byteswap_func": byteswap_q6_k,
},
gguf.GGMLQuantizationType.Q4_0: byteswap_q4_0,
gguf.GGMLQuantizationType.Q8_0: byteswap_q8_0,
gguf.GGMLQuantizationType.Q4_K: byteswap_q4_k,
gguf.GGMLQuantizationType.Q6_K: byteswap_q6_k,
gguf.GGMLQuantizationType.MXFP4: byteswap_noop,
}
@ -135,8 +129,8 @@ def convert_byteorder(reader: gguf.GGUFReader, args: argparse.Namespace) -> None
tensor.data.resize(newshape)
block_size = byteswap_tensors[tensor.tensor_type]["block_size"]
byteswap_func = byteswap_tensors[tensor.tensor_type]["byteswap_func"]
block_size = gguf.constants.GGML_QUANT_SIZES[tensor.tensor_type][1]
byteswap_func = byteswap_tensors[tensor.tensor_type]
n_blocks = len(tensor.data) // block_size
for block_num in (inner_pbar := tqdm(range(n_blocks), desc="Byte-swapping Blocks", leave=False)):
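
For reference (not part of the diff), the hard-coded per-type block sizes are replaced by a lookup into the shared gguf.constants.GGML_QUANT_SIZES table, which maps each quantization type to an (elements per block, bytes per block) pair. A small sketch of that lookup, checked against the byte counts from the removed comments above:

import gguf

def block_bytes(qtype: gguf.GGMLQuantizationType) -> int:
    # GGML_QUANT_SIZES[qtype] is (block size in elements, block size in bytes)
    return gguf.constants.GGML_QUANT_SIZES[qtype][1]

assert block_bytes(gguf.GGMLQuantizationType.Q4_0) == 18    # f16 delta + 16 quant bytes
assert block_bytes(gguf.GGMLQuantizationType.Q8_0) == 34    # f16 delta + 32 quant bytes
assert block_bytes(gguf.GGMLQuantizationType.Q6_K) == 210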

View File

@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow):
# Add tensors (including data)
for tensor in self.reader.tensors:
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type)
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)
# Write header and metadata
writer.open_output_file(Path(file_path))

View File

@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
writer.write_ti_data_to_file()
for tensor in reader.tensors:
writer.write_tensor_data(tensor.data)
writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
bar.update(tensor.n_bytes)
writer.close()

View File

@ -672,10 +672,11 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_IN: (
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.in_proj", # mamba-hf
"backbone.layers.{bid}.mixer.in_proj", # mamba
"model.layers.{bid}.mamba.in_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.in_proj", # plamo2
"model.layers.{bid}.linear_attn.in_proj_qkvz", # qwen3next
),
MODEL_TENSOR.SSM_CONV1D: (
@ -683,6 +684,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.conv1d", # mamba
"model.layers.{bid}.mamba.conv1d", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.conv1d", # plamo2
"model.layers.{bid}.linear_attn.conv1d", # qwen3next
),
MODEL_TENSOR.SSM_X: (
@ -697,6 +699,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.dt_proj", # mamba
"model.layers.{bid}.mamba.dt_proj", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.dt_proj", # plamo2
"model.layers.{bid}.linear_attn.dt_proj", # qwen3next
),
MODEL_TENSOR.SSM_DT_NORM: (
@ -709,6 +712,7 @@ class TensorNameMap:
"backbone.layers.{bid}.mixer.A_log", # mamba
"model.layers.{bid}.mamba.A_log", # jamba falcon-h1 granite-hybrid
"model.layers.layers.{bid}.mixer.A_log", # plamo2
"model.layers.{bid}.linear_attn.A_log", # qwen3next
),
MODEL_TENSOR.SSM_B_NORM: (
@ -731,17 +735,23 @@ class TensorNameMap:
),
MODEL_TENSOR.SSM_NORM: (
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"backbone.layers.{bid}.mixer.norm", # mamba2
"model.layers.{bid}.mamba.norm", # falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.norm", # qwen3next
"backbone.layers.{bid}.mixer.norm", # mamba2
),
MODEL_TENSOR.SSM_OUT: (
"model.layers.{bid}.out_proj", # mamba-hf
"backbone.layers.{bid}.mixer.out_proj", # mamba
"model.layers.{bid}.mamba.out_proj", # jamba falcon-h1 granite-hybrid
"model.layers.{bid}.linear_attn.out_proj", # qwen3next
"model.layers.layers.{bid}.mixer.out_proj", # plamo2
),
MODEL_TENSOR.SSM_BETA_ALPHA: (
"model.layers.{bid}.linear_attn.in_proj_ba", # qwen3next
),
MODEL_TENSOR.TIME_MIX_W0: (
"model.layers.{bid}.attention.w0", # rwkv7
),

View File

@ -114,6 +114,7 @@ add_library(llama
models/qwen3vl.cpp
models/qwen3vl-moe.cpp
models/qwen3moe.cpp
models/qwen3next.cpp
models/refact.cpp
models/rnd1.cpp
models/rwkv6-base.cpp

View File

@ -32,6 +32,7 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
{ LLM_ARCH_QWEN2VL, "qwen2vl" },
{ LLM_ARCH_QWEN3, "qwen3" },
{ LLM_ARCH_QWEN3MOE, "qwen3moe" },
{ LLM_ARCH_QWEN3NEXT, "qwen3next" },
{ LLM_ARCH_QWEN3VL, "qwen3vl" },
{ LLM_ARCH_QWEN3VLMOE, "qwen3vlmoe" },
{ LLM_ARCH_PHI2, "phi2" },
@ -829,6 +830,38 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
},
},
{
LLM_ARCH_QWEN3NEXT,
{
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_OUTPUT_NORM, "output_norm" },
{ LLM_TENSOR_OUTPUT, "output" },
{ LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" },
{ LLM_TENSOR_ATTN_POST_NORM, "blk.%d.post_attention_norm" },
{ LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" },
{ LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" },
{ LLM_TENSOR_ATTN_K, "blk.%d.attn_k" },
{ LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
{ LLM_TENSOR_ATTN_V, "blk.%d.attn_v" },
{ LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" },
{ LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" },
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
{ LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" },
{ LLM_TENSOR_FFN_GATE_INP_SHEXP, "blk.%d.ffn_gate_inp_shexp" },
{ LLM_TENSOR_FFN_GATE_SHEXP, "blk.%d.ffn_gate_shexp" },
{ LLM_TENSOR_FFN_DOWN_SHEXP, "blk.%d.ffn_down_shexp" },
{ LLM_TENSOR_FFN_UP_SHEXP, "blk.%d.ffn_up_shexp" },
{ LLM_TENSOR_SSM_A, "blk.%d.ssm_a" },
{ LLM_TENSOR_SSM_CONV1D, "blk.%d.ssm_conv1d" },
{ LLM_TENSOR_SSM_DT, "blk.%d.ssm_dt" },
{ LLM_TENSOR_SSM_BETA_ALPHA, "blk.%d.ssm_ba" },
{ LLM_TENSOR_SSM_IN, "blk.%d.ssm_in" },
{ LLM_TENSOR_SSM_NORM, "blk.%d.ssm_norm" },
{ LLM_TENSOR_SSM_OUT, "blk.%d.ssm_out" },
},
},
{
LLM_ARCH_QWEN3VL,
{
@ -2237,7 +2270,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_OUTPUT, "output" },
}
},
@ -2259,7 +2292,7 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
{ LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" },
{ LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" },
{ LLM_TENSOR_TOKEN_EMBD, "token_embd" },
{ LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" },
{ LLM_TENSOR_OUTPUT_NORM, "token_embd_norm" }, // note: wrong tensor name
{ LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" },
{ LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" },
{ LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" },
@ -2487,11 +2520,21 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
},
};
// declare information about the model weight tensors:
// - the layer in which the tensor is going to be used. this is needed in order to assign the correct buffer type for the weight
// - the operator which is going to use the weight. this is needed to determine if the respective backend supports the operator
//
// for example, input layers are usually assigned to CPU/host buffer types
//
// a mismatch between the declared information and the actual layer/op in which the tensor is used can lead to sub-optimal
// assignment of the buffer types and extra overhead during computation
// example: https://github.com/ggml-org/llama.cpp/pull/17548
//
static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_TOKEN_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_POS_EMBD, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_TYPES, {LLM_TENSOR_LAYER_INPUT, GGML_OP_GET_ROWS}},
{LLM_TENSOR_TOKEN_EMBD_NORM, {LLM_TENSOR_LAYER_INPUT, GGML_OP_MUL}},
{LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
{LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}},
@ -2546,6 +2589,7 @@ static const std::map<llm_tensor, llm_tensor_info> LLM_TENSOR_INFOS = {
{LLM_TENSOR_SSM_X, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_DT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_OUT, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_SSM_BETA_ALPHA, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_W2, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
{LLM_TENSOR_TIME_MIX_A1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL_MAT}},
@ -2744,6 +2788,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
case LLM_ARCH_NEMOTRON_H:
case LLM_ARCH_QWEN3NEXT:
return true;
default:
return false;

View File

@ -36,6 +36,7 @@ enum llm_arch {
LLM_ARCH_QWEN2VL,
LLM_ARCH_QWEN3,
LLM_ARCH_QWEN3MOE,
LLM_ARCH_QWEN3NEXT,
LLM_ARCH_QWEN3VL,
LLM_ARCH_QWEN3VLMOE,
LLM_ARCH_PHI2,
@ -381,6 +382,7 @@ enum llm_tensor {
LLM_TENSOR_SSM_D,
LLM_TENSOR_SSM_NORM,
LLM_TENSOR_SSM_OUT,
LLM_TENSOR_SSM_BETA_ALPHA, // qwen3next
LLM_TENSOR_TIME_MIX_W0,
LLM_TENSOR_TIME_MIX_W1,
LLM_TENSOR_TIME_MIX_W2,

View File

@ -1,5 +1,6 @@
#include "llama-context.h"
#include "llama-arch.h"
#include "llama-impl.h"
#include "llama-batch.h"
#include "llama-io.h"
@ -322,7 +323,7 @@ llama_context::llama_context(
cross.v_embd.clear();
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// avoid reserving graphs with zero outputs - assume one output per sequence
@ -575,7 +576,7 @@ bool llama_context::memory_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context");
}
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max;
const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());
@ -1849,6 +1850,9 @@ void llama_context::output_reorder() {
//
uint32_t llama_context::graph_max_nodes() const {
if (model.arch == LLM_ARCH_QWEN3NEXT) {
return std::max<uint32_t>(8192u, 32u*model.n_tensors());
}
return std::max<uint32_t>(1024u, 8u*model.n_tensors());
}

View File

@ -6,7 +6,7 @@
// bump if necessary
#define LLAMA_MAX_LAYERS 512
#define LLAMA_MAX_EXPERTS 384 // Kimi-K2
#define LLAMA_MAX_EXPERTS 512 // Qwen3 Next
enum llama_expert_gating_func_type {
LLAMA_EXPERT_GATING_FUNC_TYPE_NONE = 0,

View File

@ -2,7 +2,6 @@
#include "llama-impl.h"
#include "llama-mmap.h"
#include "llama-batch.h"
#include "llama-cparams.h"
#include "llama-model-loader.h"
@ -2225,6 +2224,29 @@ void llama_model::load_hparams(llama_model_loader & ml) {
default: type = LLM_TYPE_UNKNOWN;
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp, false);
ml.get_key(LLM_KV_EXPERT_SHARED_FEED_FORWARD_LENGTH, hparams.n_ff_shexp, false);
ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps);
// Load linear attention (gated delta net) parameters
ml.get_key(LLM_KV_SSM_CONV_KERNEL, hparams.ssm_d_conv);
ml.get_key(LLM_KV_SSM_INNER_SIZE, hparams.ssm_d_inner);
ml.get_key(LLM_KV_SSM_STATE_SIZE, hparams.ssm_d_state);
ml.get_key(LLM_KV_SSM_TIME_STEP_RANK, hparams.ssm_dt_rank);
ml.get_key(LLM_KV_SSM_GROUP_COUNT, hparams.ssm_n_group);
// Mark recurrent layers (linear attention layers)
for (uint32_t i = 0; i < hparams.n_layer; ++i) {
hparams.recurrent_layer_arr[i] = ((i + 1) % 4 != 0); // TODO: extract the magic 4 from "full_attention_interval"
}
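// e.g. with full_attention_interval == 4: layers 0,1,2 use the recurrent (linear attention) path, layer 3 uses full attention, and the pattern repeats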
switch (hparams.n_layer) {
case 80: type = LLM_TYPE_80B_A3B; break;
default: type = LLM_TYPE_UNKNOWN;
}
} break;
default: throw std::runtime_error("unsupported model architecture");
}
@ -6133,9 +6155,10 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
case LLM_ARCH_LFM2:
case LLM_ARCH_LFM2MOE:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0);
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), {n_embd}, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), {n_embd, n_vocab}, TENSOR_NOT_REQUIRED);
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED);
@ -6414,6 +6437,74 @@ bool llama_model::load_tensors(llama_model_loader & ml) {
layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0);
}
} break;
case LLM_ARCH_QWEN3NEXT:
{
tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0);
// output
output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0);
output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, TENSOR_NOT_REQUIRED);
// if output is NULL, init from the input tok embed
if (output == NULL) {
output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, TENSOR_DUPLICATED);
}
const int64_t n_ff_exp = hparams.n_ff_exp ? hparams.n_ff_exp : n_ff / n_expert_used;
// Calculate dimensions from hyperparameters
const int64_t head_k_dim = hparams.ssm_d_state;
const int64_t head_v_dim = hparams.ssm_d_state;
const int64_t n_k_heads = hparams.ssm_n_group;
const int64_t n_v_heads = hparams.ssm_dt_rank;
const int64_t key_dim = head_k_dim * n_k_heads;
const int64_t value_dim = head_v_dim * n_v_heads;
const int64_t conv_dim = key_dim * 2 + value_dim;
// Calculate projection sizes
const int64_t qkvz_dim = key_dim * 2 + value_dim * 2;
const int64_t ba_dim = n_v_heads * 2;
for (int i = 0; i < n_layer; ++i) {
auto & layer = layers[i];
layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0);
layer.attn_post_norm = create_tensor(tn(LLM_TENSOR_ATTN_POST_NORM, "weight", i), { n_embd }, 0);
if (!hparams.is_recurrent(i)) {
// Attention layers
layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head * 2 }, 0);
layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_k_gqa }, 0);
layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_v_gqa }, 0);
layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0);
// Q/K normalization for attention layers
layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0);
layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0);
} else {
// Linear attention (gated delta net) specific tensors
// Create tensors with calculated dimensions
layer.ssm_in = create_tensor(tn(LLM_TENSOR_SSM_IN, "weight", i), { n_embd, qkvz_dim }, 0);
layer.ssm_conv1d = create_tensor(tn(LLM_TENSOR_SSM_CONV1D, "weight", i), { hparams.ssm_d_conv, conv_dim }, 0);
layer.ssm_dt = create_tensor(tn(LLM_TENSOR_SSM_DT, "bias", i), { hparams.ssm_dt_rank }, 0);
layer.ssm_a = create_tensor(tn(LLM_TENSOR_SSM_A, i), { hparams.ssm_dt_rank }, 0);
layer.ssm_beta_alpha = create_tensor(tn(LLM_TENSOR_SSM_BETA_ALPHA, "weight", i), { n_embd, ba_dim }, 0);
layer.ssm_norm = create_tensor(tn(LLM_TENSOR_SSM_NORM, "weight", i), { head_v_dim }, 0);
layer.ssm_out = create_tensor(tn(LLM_TENSOR_SSM_OUT, "weight", i), { value_dim, n_embd }, 0);
}
layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), { n_embd, n_expert }, 0);
layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), { n_ff_exp, n_embd, n_expert }, 0);
layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), { n_embd, n_ff_exp, n_expert }, 0);
// Shared experts
layer.ffn_gate_inp_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP_SHEXP, "weight", i), { n_embd }, 0);
layer.ffn_gate_shexp = create_tensor(tn(LLM_TENSOR_FFN_GATE_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
layer.ffn_up_shexp = create_tensor(tn(LLM_TENSOR_FFN_UP_SHEXP, "weight", i), { n_embd, hparams.n_ff_shexp }, 0);
layer.ffn_down_shexp = create_tensor(tn(LLM_TENSOR_FFN_DOWN_SHEXP, "weight", i), { hparams.n_ff_shexp, n_embd }, 0);
}
} break;
default:
throw std::runtime_error("unknown architecture");
}
@ -6684,6 +6775,7 @@ void llama_model::print_info() const {
arch == LLM_ARCH_FALCON_H1 ||
arch == LLM_ARCH_PLAMO2 ||
arch == LLM_ARCH_GRANITE_HYBRID ||
arch == LLM_ARCH_QWEN3NEXT ||
arch == LLM_ARCH_NEMOTRON_H) {
LLAMA_LOG_INFO("%s: ssm_d_conv = %u\n", __func__, hparams.ssm_d_conv);
LLAMA_LOG_INFO("%s: ssm_d_inner = %u\n", __func__, hparams.ssm_d_inner);
@ -7425,7 +7517,11 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const {
case LLM_ARCH_PANGU_EMBED:
{
llm = std::make_unique<llm_build_pangu_embedded>(*this, params);
}break;
} break;
case LLM_ARCH_QWEN3NEXT:
{
llm = std::make_unique<llm_build_qwen3next>(*this, params);
} break;
default:
GGML_ABORT("fatal error");
}
@ -7655,6 +7751,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) {
case LLM_ARCH_COGVLM:
case LLM_ARCH_PANGU_EMBED:
case LLM_ARCH_AFMOE:
case LLM_ARCH_QWEN3NEXT:
return LLAMA_ROPE_TYPE_NEOX;
case LLM_ARCH_QWEN2VL:

View File

@ -113,6 +113,7 @@ enum llm_type {
LLM_TYPE_16B_A1B,
LLM_TYPE_21B_A3B, // Ernie MoE small
LLM_TYPE_30B_A3B,
LLM_TYPE_80B_A3B, // Qwen3 Next
LLM_TYPE_100B_A6B,
LLM_TYPE_106B_A12B, // GLM-4.5-Air
LLM_TYPE_230B_A10B, // Minimax M2
@ -309,6 +310,9 @@ struct llama_layer {
struct ggml_tensor * ssm_conv1d_b = nullptr;
struct ggml_tensor * ssm_dt_b = nullptr;
// qwen3next
struct ggml_tensor * ssm_beta_alpha = nullptr;
// rwkv
struct ggml_tensor * time_mix_w1 = nullptr;
struct ggml_tensor * time_mix_w2 = nullptr;

View File

@ -681,7 +681,9 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
}
LLAMA_LOG_DEBUG("%s: pruning tensor %s\n", __func__, it.first.c_str());
continue;
} else if (remapped_name != it.first) {
}
if (remapped_name != it.first) {
ggml_set_name(it.second.tensor, remapped_name.c_str());
LLAMA_LOG_DEBUG("%s: tensor %s remapped to %s\n", __func__, it.first.c_str(), ggml_get_name(it.second.tensor));
}
@ -726,13 +728,19 @@ static void llama_model_quantize_impl(const std::string & fname_inp, const std::
{
const auto & n_head_kv_iter = model.hparams.n_head_kv_arr.begin();
// attention layers have a non-zero number of kv heads
int32_t n_attn_layer = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
int32_t n_layer_attn = model.hparams.n_layer - std::count(n_head_kv_iter, n_head_kv_iter + model.hparams.n_layer, 0);
if (llama_model_has_encoder(&model)) {
// now n_attn_layer is the number of attention layers in the encoder
// now n_layer_attn is the number of attention layers in the encoder
// for each decoder block, there are 2 attention layers
n_attn_layer += 2 * model.hparams.dec_n_layer;
n_layer_attn += 2 * model.hparams.dec_n_layer;
}
GGML_ASSERT((qs.n_attention_wv == n_attn_layer - pruned_attention_w) && "n_attention_wv is unexpected");
// note: for linear-attention models (such as Qwen3 Next) this is the number of linear layers
const int32_t n_layer_recr = std::count(model.hparams.recurrent_layer_arr.begin(), model.hparams.recurrent_layer_arr.end(), true);
LLAMA_LOG_INFO("%s: n_layer_attn = %d, n_layer_recr = %d, pruned_attention_w = %d\n", __func__, n_layer_attn, n_layer_recr, pruned_attention_w);
GGML_ASSERT((qs.n_attention_wv == n_layer_attn - pruned_attention_w - n_layer_recr) && "n_attention_wv is unexpected");
}
size_t total_size_org = 0;

View File

@ -9,6 +9,8 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
ggml_tensor * cur = build_inp_embd(model.tok_embd);
cb(cur, "model.embed_tokens", -1);
ggml_build_forward_expand(gf, cur);
ggml_tensor * inp_pos = build_inp_pos();
auto * inp_hybrid = build_inp_mem_hybrid();
ggml_tensor * inp_out_ids = build_inp_out_ids();
@ -40,12 +42,12 @@ llm_build_lfm2::llm_build_lfm2(const llama_model & model, const llm_graph_params
cur = ggml_add(ctx0, cur, ffn_out);
}
cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "model.embedding_norm", -1);
cur = build_norm(cur, model.output_norm, NULL, LLM_NORM_RMS, -1);
cb(cur, "result_norm", -1);
res->t_embd = cur;
cur = build_lora_mm(model.output, cur);
cb(cur, "lm_head", -1);
cb(cur, "result_output", -1);
res->t_logits = cur;

View File

@ -2,8 +2,9 @@
#include "../llama-model.h"
#include "../llama-graph.h"
#include "../llama-memory-recurrent.h"
// TODO: remove in follow-up PR - move to .cpp files
#include "../llama-memory-recurrent.h"
#include <cmath>
struct llm_graph_context_mamba : public llm_graph_context {
@ -421,7 +422,56 @@ struct llm_build_qwen3vl : public llm_graph_context {
struct llm_build_qwen3vlmoe : public llm_graph_context {
llm_build_qwen3vlmoe(const llama_model & model, const llm_graph_params & params);
};
struct llm_build_qwen3next : public llm_graph_context_mamba {
llm_build_qwen3next(const llama_model & model, const llm_graph_params & params);
private:
ggml_tensor * build_layer_attn(
llm_graph_input_attn_kv * inp_attn,
ggml_tensor * cur,
ggml_tensor * inp_pos,
int il);
ggml_tensor * build_layer_attn_linear(
llm_graph_input_rs * inp,
ggml_tensor * cur,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_layer_ffn(
ggml_tensor * cur,
int il);
ggml_tensor * build_delta_net_recurrent(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_delta_net_chunking(
ggml_tensor * q,
ggml_tensor * k,
ggml_tensor * v,
ggml_tensor * g,
ggml_tensor * beta,
ggml_tensor * state,
ggml_tensor * causal_mask,
ggml_tensor * identity,
int il);
ggml_tensor * build_norm_gated(
ggml_tensor * input,
ggml_tensor * weights,
ggml_tensor * gate,
int layer);
const llama_model & model;
};
struct llm_build_qwen : public llm_graph_context {
llm_build_qwen(const llama_model & model, const llm_graph_params & params);

src/models/qwen3next.cpp: new file, 1042 lines (diff suppressed because it is too large)

View File

@ -196,7 +196,7 @@ if (NOT WIN32)
llama_build_and_test(test-arg-parser.cpp)
endif()
if (NOT LLAMA_SANITIZE_ADDRESS)
if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC)
# TODO: repair known memory leaks
llama_build_and_test(test-opt.cpp)
endif()

View File

@ -1446,14 +1446,14 @@ struct test_case {
const uint64_t target_flops_cpu = 8ULL * GFLOP;
const uint64_t target_flops_gpu = 100ULL * GFLOP;
uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
} else {
// based on memory size
const size_t GB = 1ULL << 30;
const size_t target_size_cpu = 8 * GB;
const size_t target_size_gpu = 32 * GB;
size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
}
// duplicate the op
@ -7935,6 +7935,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 16416, 1, 128, {8, 1}, {4, 1}, {0, 2, 1, 3}));
test_cases.emplace_back(new test_mul_mat(GGML_TYPE_F16, GGML_TYPE_F32, 128, 1, 16416, {8, 1}, {4, 1}, {0, 1, 2, 3}, 2*16416));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 64, 64, 4, 2 }, { 6, 64, 4, 2 }));
test_cases.emplace_back(new test_solve_tri(GGML_TYPE_F32, { 128, 128, 4, 1 }, { 8, 128, 4, 1 }));
for (int bs : {1, 2, 3, 4, 5, 8, 512}) {
for (ggml_type type_a : all_types) {
for (ggml_type type_b : {GGML_TYPE_F32}) {
@ -8042,6 +8045,8 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 1, 1, 1}));
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {200000, 16, 1, 1}));
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
for (auto k : {1, 10, 40, 400}) {
for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) {

View File

@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
test({
SUCCESS,
"literal string with escapes",
R"""({
"properties": {
"code": {
"const": " \r \n \" \\ ",
"description": "Generated code",
"title": "Code",
"type": "string"
}
},
"required": [
"code"
],
"title": "DecoderResponse",
"type": "object"
})""",
R"""(
code ::= "\" \\r \\n \\\" \\\\ \"" space
code-kv ::= "\"code\"" space ":" space code
root ::= "{" space code-kv "}" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
}
int main() {

View File

@ -1175,10 +1175,11 @@ struct clip_graph {
cb(K, "resampler_K", -1);
cb(V, "resampler_V", -1);
float resampler_kq_scale = 1.0f/ sqrtf(float(d_head));
embeddings = build_attn(
model.mm_model_attn_o_w,
model.mm_model_attn_o_b,
Q, K, V, nullptr, kq_scale, -1);
Q, K, V, nullptr, resampler_kq_scale, -1);
cb(embeddings, "resampler_attn_out", -1);
}
// layernorm

View File

@ -7,6 +7,7 @@ Set of LLM REST APIs and a simple web front end to interact with llama.cpp.
**Features:**
* LLM inference of F16 and quantized models on GPU and CPU
* [OpenAI API](https://github.com/openai/openai-openapi) compatible chat completions and embeddings routes
* [Anthropic Messages API](https://docs.anthropic.com/en/api/messages) compatible chat completions
* Reranking endpoint (https://github.com/ggml-org/llama.cpp/pull/9510)
* Parallel decoding with multi-user support
* Continuous batching
@ -30,9 +31,10 @@ The project is under active development, and we are [looking for feedback and co
| -------- | ----------- |
| `-h, --help, --usage` | print usage and exit |
| `--version` | show version and build info |
| `-cl, --cache-list` | show list of models in cache |
| `--completion-bash` | print source-able bash completion script for llama.cpp |
| `--verbose-prompt` | print a verbose prompt before generation (default: false) |
| `-t, --threads N` | number of threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
| `-t, --threads N` | number of CPU threads to use during generation (default: -1)<br/>(env: LLAMA_ARG_THREADS) |
| `-tb, --threads-batch N` | number of threads to use during batch and prompt processing (default: same as --threads) |
| `-C, --cpu-mask M` | CPU affinity mask: arbitrarily long hex. Complements cpu-range (default: "") |
| `-Cr, --cpu-range lo-hi` | range of CPUs for affinity. Complements --cpu-mask |
@ -51,7 +53,7 @@ The project is under active development, and we are [looking for feedback and co
| `--keep N` | number of tokens to keep from the initial prompt (default: 0, -1 = all) |
| `--swa-full` | use full-size SWA cache (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055)<br/>(env: LLAMA_ARG_SWA_FULL) |
| `--kv-unified, -kvu` | use single unified KV buffer for the KV cache of all sequences (default: false)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/14363)<br/>(env: LLAMA_ARG_KV_SPLIT) |
| `-fa, --flash-attn` | enable Flash Attention (default: disabled)<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `-fa, --flash-attn [on\|off\|auto]` | set Flash Attention use ('on', 'off', or 'auto', default: 'auto')<br/>(env: LLAMA_ARG_FLASH_ATTN) |
| `--no-perf` | disable internal libllama performance timings (default: false)<br/>(env: LLAMA_ARG_NO_PERF) |
| `-e, --escape` | process escape sequences (\n, \r, \t, \', \", \\) (default: true) |
| `--no-escape` | do not process escape sequences |
@ -61,11 +63,12 @@ The project is under active development, and we are [looking for feedback and co
| `--rope-freq-scale N` | RoPE frequency scaling factor, expands context by a factor of 1/N<br/>(env: LLAMA_ARG_ROPE_FREQ_SCALE) |
| `--yarn-orig-ctx N` | YaRN: original context size of model (default: 0 = model training context size)<br/>(env: LLAMA_ARG_YARN_ORIG_CTX) |
| `--yarn-ext-factor N` | YaRN: extrapolation mix factor (default: -1.0, 0.0 = full interpolation)<br/>(env: LLAMA_ARG_YARN_EXT_FACTOR) |
| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: 1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: 1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: 32.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
| `--yarn-attn-factor N` | YaRN: scale sqrt(t) or attention magnitude (default: -1.0)<br/>(env: LLAMA_ARG_YARN_ATTN_FACTOR) |
| `--yarn-beta-slow N` | YaRN: high correction dim or alpha (default: -1.0)<br/>(env: LLAMA_ARG_YARN_BETA_SLOW) |
| `--yarn-beta-fast N` | YaRN: low correction dim or beta (default: -1.0)<br/>(env: LLAMA_ARG_YARN_BETA_FAST) |
| `-nkvo, --no-kv-offload` | disable KV offload<br/>(env: LLAMA_ARG_NO_KV_OFFLOAD) |
| `-nr, --no-repack` | disable weight repacking<br/>(env: LLAMA_ARG_NO_REPACK) |
| `--no-host` | bypass host buffer allowing extra buffers to be used<br/>(env: LLAMA_ARG_NO_HOST) |
| `-ctk, --cache-type-k TYPE` | KV cache data type for K<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_K) |
| `-ctv, --cache-type-v TYPE` | KV cache data type for V<br/>allowed values: f32, f16, bf16, q8_0, q4_0, q4_1, iq4_nl, q5_0, q5_1<br/>(default: f16)<br/>(env: LLAMA_ARG_CACHE_TYPE_V) |
| `-dt, --defrag-thold N` | KV cache defragmentation threshold (DEPRECATED)<br/>(env: LLAMA_ARG_DEFRAG_THOLD) |
@ -78,7 +81,7 @@ The project is under active development, and we are [looking for feedback and co
| `--override-tensor, -ot <tensor name pattern>=<buffer type>,...` | override tensor buffer type |
| `--cpu-moe, -cmoe` | keep all Mixture of Experts (MoE) weights in the CPU<br/>(env: LLAMA_ARG_CPU_MOE) |
| `--n-cpu-moe, -ncmoe N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU<br/>(env: LLAMA_ARG_N_CPU_MOE) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | number of layers to store in VRAM<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-ngl, --gpu-layers, --n-gpu-layers N` | max. number of layers to store in VRAM (default: -1)<br/>(env: LLAMA_ARG_N_GPU_LAYERS) |
| `-sm, --split-mode {none,layer,row}` | how to split the model across multiple GPUs, one of:<br/>- none: use one GPU only<br/>- layer (default): split layers and KV across GPUs<br/>- row: split rows across GPUs<br/>(env: LLAMA_ARG_SPLIT_MODE) |
| `-ts, --tensor-split N0,N1,N2,...` | fraction of the model to offload to each GPU, comma-separated list of proportions, e.g. 3,1<br/>(env: LLAMA_ARG_TENSOR_SPLIT) |
| `-mg, --main-gpu INDEX` | the GPU to use for the model (with split-mode = none), or for intermediate results and KV (with split-mode = row) (default: 0)<br/>(env: LLAMA_ARG_MAIN_GPU) |
@ -92,6 +95,7 @@ The project is under active development, and we are [looking for feedback and co
| `--control-vector-layer-range START END` | layer range to apply the control vector(s) to, start and end inclusive |
| `-m, --model FNAME` | model path (default: `models/$filename` with filename from `--hf-file` or `--model-url` if set, otherwise models/7B/ggml-model-f16.gguf)<br/>(env: LLAMA_ARG_MODEL) |
| `-mu, --model-url MODEL_URL` | model download url (default: unused)<br/>(env: LLAMA_ARG_MODEL_URL) |
| `-dr, --docker-repo [<repo>/]<model>[:quant]` | Docker Hub model repository. repo is optional, defaults to ai/. quant is optional, defaults to :latest.<br/>example: gemma3<br/>(default: unused)<br/>(env: LLAMA_ARG_DOCKER_REPO) |
| `-hf, -hfr, --hf-repo <user>/<model>[:quant]` | Hugging Face model repository; quant is optional, case-insensitive, default to Q4_K_M, or falls back to the first file in the repo if Q4_K_M doesn't exist.<br/>mmproj is also downloaded automatically if available. to disable, add --no-mmproj<br/>example: unsloth/phi-4-GGUF:q4_k_m<br/>(default: unused)<br/>(env: LLAMA_ARG_HF_REPO) |
| `-hfd, -hfrd, --hf-repo-draft <user>/<model>[:quant]` | Same as --hf-repo, but for the draft model (default: unused)<br/>(env: LLAMA_ARG_HFD_REPO) |
| `-hff, --hf-file FILE` | Hugging Face model file. If specified, it will override the quant in --hf-repo (default: unused)<br/>(env: LLAMA_ARG_HF_FILE) |
@ -100,7 +104,7 @@ The project is under active development, and we are [looking for feedback and co
| `-hft, --hf-token TOKEN` | Hugging Face access token (default: value from HF_TOKEN environment variable)<br/>(env: HF_TOKEN) |
| `--log-disable` | Log disable |
| `--log-file FNAME` | Log to file |
| `--log-colors` | Enable colored logging<br/>(env: LLAMA_LOG_COLORS) |
| `--log-colors [on\|off\|auto]` | Set colored logging ('on', 'off', or 'auto', default: 'auto')<br/>'auto' enables colors when output is to a terminal<br/>(env: LLAMA_LOG_COLORS) |
| `-v, --verbose, --log-verbose` | Set verbosity level to infinity (i.e. log all messages, useful for debugging) |
| `--offline` | Offline mode: forces use of cache, prevents network access<br/>(env: LLAMA_OFFLINE) |
| `-lv, --verbosity, --log-verbosity N` | Set the verbosity threshold. Messages with a higher verbosity will be ignored.<br/>(env: LLAMA_LOG_VERBOSITY) |
@ -151,7 +155,8 @@ The project is under active development, and we are [looking for feedback and co
| Argument | Explanation |
| -------- | ----------- |
| `--swa-checkpoints N` | max number of SWA checkpoints per slot to create (default: 3)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_SWA_CHECKPOINTS) |
| `--ctx-checkpoints, --swa-checkpoints N` | max number of context checkpoints to create per slot (default: 8)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/15293)<br/>(env: LLAMA_ARG_CTX_CHECKPOINTS) |
| `--cache-ram, -cram N` | set the maximum cache size in MiB (default: 8192, -1 - no limit, 0 - disable)<br/>[(more info)](https://github.com/ggml-org/llama.cpp/pull/16391)<br/>(env: LLAMA_ARG_CACHE_RAM) |
| `--no-context-shift` | disables context shift on infinite text generation (default: enabled)<br/>(env: LLAMA_ARG_NO_CONTEXT_SHIFT) |
| `--context-shift` | enables context shift on infinite text generation (default: disabled)<br/>(env: LLAMA_ARG_CONTEXT_SHIFT) |
| `-r, --reverse-prompt PROMPT` | halt generation at PROMPT, return control in interactive mode<br/> |
@ -165,6 +170,8 @@ The project is under active development, and we are [looking for feedback and co
| `--mmproj-url URL` | URL to a multimodal projector file. see tools/mtmd/README.md<br/>(env: LLAMA_ARG_MMPROJ_URL) |
| `--no-mmproj` | explicitly disable multimodal projector, useful when using -hf<br/>(env: LLAMA_ARG_NO_MMPROJ) |
| `--no-mmproj-offload` | do not offload multimodal projector to GPU<br/>(env: LLAMA_ARG_NO_MMPROJ_OFFLOAD) |
| `--image-min-tokens N` | minimum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MIN_TOKENS) |
| `--image-max-tokens N` | maximum number of tokens each image can take, only used by vision models with dynamic resolution (default: read from model)<br/>(env: LLAMA_ARG_IMAGE_MAX_TOKENS) |
| `--override-tensor-draft, -otd <tensor name pattern>=<buffer type>,...` | override tensor buffer type for draft model |
| `--cpu-moe-draft, -cmoed` | keep all Mixture of Experts (MoE) weights in the CPU for the draft model<br/>(env: LLAMA_ARG_CPU_MOE_DRAFT) |
| `--n-cpu-moe-draft, -ncmoed N` | keep the Mixture of Experts (MoE) weights of the first N layers in the CPU for the draft model<br/>(env: LLAMA_ARG_N_CPU_MOE_DRAFT) |
@ -189,13 +196,14 @@ The project is under active development, and we are [looking for feedback and co
| `--slots` | enable slots monitoring endpoint (default: enabled)<br/>(env: LLAMA_ARG_ENDPOINT_SLOTS) |
| `--no-slots` | disables slots monitoring endpoint<br/>(env: LLAMA_ARG_NO_ENDPOINT_SLOTS) |
| `--slot-save-path PATH` | path to save slot kv cache (default: disabled) |
| `--jinja` | use jinja template for chat (default: disabled)<br/>(env: LLAMA_ARG_JINJA) |
| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: deepseek)<br/>(env: LLAMA_ARG_THINK) |
| `--jinja` | use jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_JINJA) |
| `--no-jinja` | disable jinja template for chat (default: enabled)<br/><br/>(env: LLAMA_ARG_NO_JINJA) |
| `--reasoning-format FORMAT` | controls whether thought tags are allowed and/or extracted from the response, and in which format they're returned; one of:<br/>- none: leaves thoughts unparsed in `message.content`<br/>- deepseek: puts thoughts in `message.reasoning_content`<br/>- deepseek-legacy: keeps `<think>` tags in `message.content` while also populating `message.reasoning_content`<br/>(default: auto)<br/>(env: LLAMA_ARG_THINK) |
| `--reasoning-budget N` | controls the amount of thinking allowed; currently only one of: -1 for unrestricted thinking budget, or 0 to disable thinking (default: -1)<br/>(env: LLAMA_ARG_THINK_BUDGET) |
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
| `--chat-template JINJA_TEMPLATE` | set custom jinja chat template (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE) |
| `--chat-template-file JINJA_TEMPLATE_FILE` | set custom jinja chat template file (default: template taken from model's metadata)<br/>if suffix/prefix are specified, template will be disabled<br/>only commonly used templates are accepted (unless --jinja is set before this flag):<br/>list of built-in templates:<br/>bailing, bailing-think, bailing2, chatglm3, chatglm4, chatml, command-r, deepseek, deepseek2, deepseek3, exaone3, exaone4, falcon3, gemma, gigachat, glmedge, gpt-oss, granite, grok-2, hunyuan-dense, hunyuan-moe, kimi-k2, llama2, llama2-sys, llama2-sys-bos, llama2-sys-strip, llama3, llama4, megrez, minicpm, mistral-v1, mistral-v3, mistral-v3-tekken, mistral-v7, mistral-v7-tekken, monarch, openchat, orion, pangu-embedded, phi3, phi4, rwkv-world, seed_oss, smolvlm, vicuna, vicuna-orca, yandex, zephyr<br/>(env: LLAMA_ARG_CHAT_TEMPLATE_FILE) |
| `--no-prefill-assistant` | whether to prefill the assistant's response if the last message is an assistant message (default: prefill enabled)<br/>when this flag is set, if the last message is an assistant message then it will be treated as a full message and not prefilled<br/><br/>(env: LLAMA_ARG_NO_PREFILL_ASSISTANT) |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.50, 0.0 = disabled)<br/> |
| `-sps, --slot-prompt-similarity SIMILARITY` | how much the prompt of a request must match the prompt of a slot in order to use that slot (default: 0.10, 0.0 = disabled)<br/> |
| `--lora-init-without-apply` | load LoRA adapters without applying them (apply later via POST /lora-adapters) (default: disabled) |
| `-td, --threads-draft N` | number of threads to use during generation (default: same as --threads) |
| `-tbd, --threads-batch-draft N` | number of threads to use during batch and prompt processing (default: same as --threads-draft) |
@ -209,15 +217,17 @@ The project is under active development, and we are [looking for feedback and co
| `--spec-replace TARGET DRAFT` | translate the string in TARGET into DRAFT if the draft model and main model are not compatible |
| `-mv, --model-vocoder FNAME` | vocoder model for audio generation (default: unused) |
| `--tts-use-guide-tokens` | Use guide tokens to improve TTS word recall |
| `--embd-bge-small-en-default` | use default bge-small-en-v1.5 model (note: can download weights from the internet) |
| `--embd-e5-small-en-default` | use default e5-small-v2 model (note: can download weights from the internet) |
| `--embd-gte-small-default` | use default gte-small model (note: can download weights from the internet) |
| `--embd-gemma-default` | use default EmbeddingGemma model (note: can download weights from the internet) |
| `--fim-qwen-1.5b-default` | use default Qwen 2.5 Coder 1.5B (note: can download weights from the internet) |
| `--fim-qwen-3b-default` | use default Qwen 2.5 Coder 3B (note: can download weights from the internet) |
| `--fim-qwen-7b-default` | use default Qwen 2.5 Coder 7B (note: can download weights from the internet) |
| `--fim-qwen-7b-spec` | use Qwen 2.5 Coder 7B + 0.5B draft for speculative decoding (note: can download weights from the internet) |
| `--fim-qwen-14b-spec` | use Qwen 2.5 Coder 14B + 0.5B draft for speculative decoding (note: can download weights from the internet) |
| `--fim-qwen-30b-default` | use default Qwen 3 Coder 30B A3B Instruct (note: can download weights from the internet) |
| `--gpt-oss-20b-default` | use gpt-oss-20b (note: can download weights from the internet) |
| `--gpt-oss-120b-default` | use gpt-oss-120b (note: can download weights from the internet) |
| `--vision-gemma-4b-default` | use Gemma 3 4B QAT (note: can download weights from the internet) |
| `--vision-gemma-12b-default` | use Gemma 3 12B QAT (note: can download weights from the internet) |
Note: If both the command line argument and the environment variable are set for the same param, the argument will take precedence over the env var.
@ -1343,6 +1353,77 @@ See [OpenAI Embeddings API documentation](https://platform.openai.com/docs/api-r
}'
```
### POST `/v1/messages`: Anthropic-compatible Messages API
Given a list of `messages`, this endpoint returns the assistant's response. Streaming is supported via Server-Sent Events (SSE). While full compatibility with the Anthropic API spec is not guaranteed, in practice it is sufficient for many client applications.
*Options:*
See the [Anthropic Messages API documentation](https://docs.anthropic.com/en/api/messages). Tool use requires the `--jinja` flag (enabled by default).
`model`: Model identifier (required)
`messages`: Array of message objects with `role` and `content` (required)
`max_tokens`: Maximum tokens to generate (default: 4096)
`system`: System prompt as string or array of content blocks
`temperature`: Sampling temperature 0-1 (default: 1.0)
`top_p`: Nucleus sampling (default: 1.0)
`top_k`: Top-k sampling
`stop_sequences`: Array of stop sequences
`stream`: Enable streaming (default: false)
`tools`: Array of tool definitions (requires `--jinja`)
`tool_choice`: Tool selection mode (`{"type": "auto"}`, `{"type": "any"}`, or `{"type": "tool", "name": "..."}`)
*Examples:*
```shell
curl http://localhost:8080/v1/messages \
-H "Content-Type: application/json" \
-H "x-api-key: your-api-key" \
-d '{
"model": "gpt-4",
"max_tokens": 1024,
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello!"}
]
}'
```
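
For reference, here is a sketch of a streaming request that also declares a tool. This is illustrative only: the `get_weather` tool, its schema, and the API key are placeholders, and whether the model actually emits a `tool_use` block depends on the loaded model and its chat template.

```shell
curl http://localhost:8080/v1/messages \
  -H "Content-Type: application/json" \
  -H "x-api-key: your-api-key" \
  -d '{
    "model": "gpt-4",
    "max_tokens": 256,
    "stream": true,
    "tools": [{
      "name": "get_weather",
      "description": "Get the current weather in a location",
      "input_schema": {
        "type": "object",
        "properties": {
          "location": {"type": "string", "description": "City name"}
        },
        "required": ["location"]
      }
    }],
    "messages": [
      {"role": "user", "content": "What is the weather in Paris?"}
    ]
  }'
```

When `stream` is true, the response is delivered as Anthropic-style SSE events: `message_start`, `content_block_start`, `content_block_delta`, `content_block_stop`, `message_delta`, and finally `message_stop` (no `data: [DONE]` sentinel is sent).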
### POST `/v1/messages/count_tokens`: Token Counting
Counts the number of tokens in a request without generating a response.
Accepts the same parameters as `/v1/messages`. The `max_tokens` parameter is not required.
*Example:*
```shell
curl http://localhost:8080/v1/messages/count_tokens \
-H "Content-Type: application/json" \
-d '{
"model": "gpt-4",
"messages": [
{"role": "user", "content": "Hello!"}
]
}'
```
*Response:*
```json
{"input_tokens": 10}
```
## More examples
### Interactive mode

View File

@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = {
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g;
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' };
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');

View File

@ -725,7 +725,6 @@ std::vector<server_tokens> tokenize_input_prompts(const llama_vocab * vocab, mtm
return result;
}
//
// OAI utils
//
@ -1048,6 +1047,222 @@ json oaicompat_chat_params_parse(
return llama_params;
}
json convert_anthropic_to_oai(const json & body) {
json oai_body;
// Convert system prompt
json oai_messages = json::array();
auto system_param = json_value(body, "system", json());
if (!system_param.is_null()) {
std::string system_content;
if (system_param.is_string()) {
system_content = system_param.get<std::string>();
} else if (system_param.is_array()) {
for (const auto & block : system_param) {
if (json_value(block, "type", std::string()) == "text") {
system_content += json_value(block, "text", std::string());
}
}
}
oai_messages.push_back({
{"role", "system"},
{"content", system_content}
});
}
// Convert messages
if (!body.contains("messages")) {
throw std::runtime_error("'messages' is required");
}
const json & messages = body.at("messages");
if (messages.is_array()) {
for (const auto & msg : messages) {
std::string role = json_value(msg, "role", std::string());
if (!msg.contains("content")) {
if (role == "assistant") {
continue;
}
oai_messages.push_back(msg);
continue;
}
const json & content = msg.at("content");
if (content.is_string()) {
oai_messages.push_back(msg);
continue;
}
if (!content.is_array()) {
oai_messages.push_back(msg);
continue;
}
json tool_calls = json::array();
json converted_content = json::array();
json tool_results = json::array();
bool has_tool_calls = false;
for (const auto & block : content) {
std::string type = json_value(block, "type", std::string());
if (type == "text") {
converted_content.push_back(block);
} else if (type == "image") {
json source = json_value(block, "source", json::object());
std::string source_type = json_value(source, "type", std::string());
if (source_type == "base64") {
std::string media_type = json_value(source, "media_type", std::string("image/jpeg"));
std::string data = json_value(source, "data", std::string());
std::ostringstream ss;
ss << "data:" << media_type << ";base64," << data;
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", ss.str()}
}}
});
} else if (source_type == "url") {
std::string url = json_value(source, "url", std::string());
converted_content.push_back({
{"type", "image_url"},
{"image_url", {
{"url", url}
}}
});
}
} else if (type == "tool_use") {
tool_calls.push_back({
{"id", json_value(block, "id", std::string())},
{"type", "function"},
{"function", {
{"name", json_value(block, "name", std::string())},
{"arguments", json_value(block, "input", json::object()).dump()}
}}
});
has_tool_calls = true;
} else if (type == "tool_result") {
std::string tool_use_id = json_value(block, "tool_use_id", std::string());
auto result_content = json_value(block, "content", json());
std::string result_text;
if (result_content.is_string()) {
result_text = result_content.get<std::string>();
} else if (result_content.is_array()) {
for (const auto & c : result_content) {
if (json_value(c, "type", std::string()) == "text") {
result_text += json_value(c, "text", std::string());
}
}
}
tool_results.push_back({
{"role", "tool"},
{"tool_call_id", tool_use_id},
{"content", result_text}
});
}
}
if (!converted_content.empty() || has_tool_calls) {
json new_msg = {{"role", role}};
if (!converted_content.empty()) {
new_msg["content"] = converted_content;
} else if (has_tool_calls) {
new_msg["content"] = "";
}
if (!tool_calls.empty()) {
new_msg["tool_calls"] = tool_calls;
}
oai_messages.push_back(new_msg);
}
for (const auto & tool_msg : tool_results) {
oai_messages.push_back(tool_msg);
}
}
}
oai_body["messages"] = oai_messages;
// Convert tools
if (body.contains("tools")) {
const json & tools = body.at("tools");
if (tools.is_array()) {
json oai_tools = json::array();
for (const auto & tool : tools) {
oai_tools.push_back({
{"type", "function"},
{"function", {
{"name", json_value(tool, "name", std::string())},
{"description", json_value(tool, "description", std::string())},
{"parameters", tool.contains("input_schema") ? tool.at("input_schema") : json::object()}
}}
});
}
oai_body["tools"] = oai_tools;
}
}
// Convert tool_choice
if (body.contains("tool_choice")) {
const json & tc = body.at("tool_choice");
if (tc.is_object()) {
std::string type = json_value(tc, "type", std::string());
if (type == "auto") {
oai_body["tool_choice"] = "auto";
} else if (type == "any" || type == "tool") {
oai_body["tool_choice"] = "required";
}
}
}
// Convert stop_sequences to stop
if (body.contains("stop_sequences")) {
oai_body["stop"] = body.at("stop_sequences");
}
// Handle max_tokens (required in Anthropic, but we're permissive)
if (body.contains("max_tokens")) {
oai_body["max_tokens"] = body.at("max_tokens");
} else {
oai_body["max_tokens"] = 4096;
}
// Pass through common params
for (const auto & key : {"temperature", "top_p", "top_k", "stream"}) {
if (body.contains(key)) {
oai_body[key] = body.at(key);
}
}
// Handle Anthropic-specific thinking param
if (body.contains("thinking")) {
json thinking = json_value(body, "thinking", json::object());
std::string thinking_type = json_value(thinking, "type", std::string());
if (thinking_type == "enabled") {
int budget_tokens = json_value(thinking, "budget_tokens", 10000);
oai_body["thinking_budget_tokens"] = budget_tokens;
}
}
// Handle Anthropic-specific metadata param
if (body.contains("metadata")) {
json metadata = json_value(body, "metadata", json::object());
std::string user_id = json_value(metadata, "user_id", std::string());
if (!user_id.empty()) {
oai_body["__metadata_user_id"] = user_id;
}
}
return oai_body;
}
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64) {
json data = json::array();
int32_t n_tokens = 0;
@ -1211,7 +1426,7 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_sse(const json & data) {
std::string format_oai_sse(const json & data) {
std::ostringstream ss;
auto send_single = [&ss](const json & data) {
ss << "data: " <<
@ -1230,6 +1445,29 @@ std::string format_sse(const json & data) {
return ss.str();
}
std::string format_anthropic_sse(const json & data) {
std::ostringstream ss;
auto send_event = [&ss](const json & event_obj) {
if (event_obj.contains("event") && event_obj.contains("data")) {
ss << "event: " << event_obj.at("event").get<std::string>() << "\n";
ss << "data: " << safe_json_to_str(event_obj.at("data")) << "\n\n";
} else {
ss << "data: " << safe_json_to_str(event_obj) << "\n\n";
}
};
if (data.is_array()) {
for (const auto & event : data) {
send_event(event);
}
} else {
send_event(data);
}
return ss.str();
}
bool is_valid_utf8(const std::string & str) {
const unsigned char* bytes = reinterpret_cast<const unsigned char*>(str.data());
const unsigned char* end = bytes + str.length();

View File

@ -294,6 +294,9 @@ json oaicompat_chat_params_parse(
const oaicompat_parser_options & opt,
std::vector<raw_buffer> & out_files);
// convert Anthropic Messages API format to OpenAI Chat Completions API format
json convert_anthropic_to_oai(const json & body);
// TODO: move it to server-task.cpp
json format_embeddings_response_oaicompat(const json & request, const json & embeddings, bool use_base64 = false);
@ -320,7 +323,10 @@ std::string tokens_to_output_formatted_string(const llama_context * ctx, const l
// format server-sent event (SSE), return the formatted string to send
// note: if data is a json array, it will be sent as multiple events, one per item
std::string format_sse(const json & data);
std::string format_oai_sse(const json & data);
// format Anthropic-style SSE with event types
std::string format_anthropic_sse(const json & data);
bool is_valid_utf8(const std::string & str);

View File

@ -136,15 +136,22 @@ bool server_http_context::init(const common_params & params) {
return true;
}
// Check for API key in the header
auto auth_header = req.get_header_value("Authorization");
// Check for API key in the Authorization header
std::string req_api_key = req.get_header_value("Authorization");
if (req_api_key.empty()) {
// retry with anthropic header
req_api_key = req.get_header_value("X-Api-Key");
}
// remove the "Bearer " prefix if needed
std::string prefix = "Bearer ";
if (auth_header.substr(0, prefix.size()) == prefix) {
std::string received_api_key = auth_header.substr(prefix.size());
if (std::find(api_keys.begin(), api_keys.end(), received_api_key) != api_keys.end()) {
return true; // API key is valid
}
if (req_api_key.substr(0, prefix.size()) == prefix) {
req_api_key = req_api_key.substr(prefix.size());
}
// validate the API key
if (std::find(api_keys.begin(), api_keys.end(), req_api_key) != api_keys.end()) {
return true; // API key is valid
}
// API key is invalid or not provided

View File

@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_
std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{
if (!running) {
RES_DBG("%s : queue result stop\n", __func__);
RES_DBG("%s : queue result stop\n", "recv");
std::terminate(); // we cannot return here since the caller is HTTP code
}
return !queue_results.empty();

View File

@ -570,15 +570,17 @@ std::vector<unsigned char> completion_token_output::str_to_bytes(const std::stri
// server_task_result_cmpl_final
//
json server_task_result_cmpl_final::to_json() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
case TASK_RESPONSE_TYPE_OAI_CMPL:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
case TASK_RESPONSE_TYPE_OAI_CHAT:
return stream ? to_json_oaicompat_chat_stream() : to_json_oaicompat_chat();
case TASK_RESPONSE_TYPE_ANTHROPIC:
return stream ? to_json_anthropic_stream() : to_json_anthropic();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
GGML_ASSERT(false && "Invalid task_response_type");
}
}
@ -773,19 +775,203 @@ json server_task_result_cmpl_final::to_json_oaicompat_chat_stream() {
return deltas;
}
json server_task_result_cmpl_final::to_json_anthropic() {
std::string stop_reason = "max_tokens";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
}
json content_blocks = json::array();
common_chat_msg msg;
if (!oaicompat_msg.empty()) {
msg = oaicompat_msg;
} else {
msg.role = "assistant";
msg.content = content;
}
if (!msg.content.empty()) {
content_blocks.push_back({
{"type", "text"},
{"text", msg.content}
});
}
for (const auto & tool_call : msg.tool_calls) {
json tool_use_block = {
{"type", "tool_use"},
{"id", tool_call.id},
{"name", tool_call.name}
};
try {
tool_use_block["input"] = json::parse(tool_call.arguments);
} catch (const std::exception &) {
tool_use_block["input"] = json::object();
}
content_blocks.push_back(tool_use_block);
}
json res = {
{"id", oaicompat_cmpl_id},
{"type", "message"},
{"role", "assistant"},
{"content", content_blocks},
{"model", oaicompat_model},
{"stop_reason", stop_reason},
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)},
{"usage", {
{"input_tokens", n_prompt_tokens},
{"output_tokens", n_decoded}
}}
};
return res;
}
json server_task_result_cmpl_final::to_json_anthropic_stream() {
json events = json::array();
std::string stop_reason = "max_tokens";
if (stop == STOP_TYPE_WORD || stop == STOP_TYPE_EOS) {
stop_reason = oaicompat_msg.tool_calls.empty() ? "end_turn" : "tool_use";
}
bool has_text = !oaicompat_msg.content.empty();
size_t num_tool_calls = oaicompat_msg.tool_calls.size();
bool text_block_started = false;
std::unordered_set<size_t> tool_calls_started;
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (!text_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"content_block", {
{"type", "text"},
{"text", ""}
}}
}}
});
text_block_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
}}
}}
});
}
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (has_text ? 1 : 0) + diff.tool_call_index;
if (tool_calls_started.find(diff.tool_call_index) == tool_calls_started.end()) {
const auto & full_tool_call = oaicompat_msg.tool_calls[diff.tool_call_index];
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", content_block_index},
{"content_block", {
{"type", "tool_use"},
{"id", full_tool_call.id},
{"name", full_tool_call.name}
}}
}}
});
tool_calls_started.insert(diff.tool_call_index);
}
if (!diff.tool_call_delta.arguments.empty()) {
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", content_block_index},
{"delta", {
{"type", "input_json_delta"},
{"partial_json", diff.tool_call_delta.arguments}
}}
}}
});
}
}
}
if (has_text) {
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", 0}
}}
});
}
for (size_t i = 0; i < num_tool_calls; i++) {
size_t content_block_index = (has_text ? 1 : 0) + i;
events.push_back({
{"event", "content_block_stop"},
{"data", {
{"type", "content_block_stop"},
{"index", content_block_index}
}}
});
}
events.push_back({
{"event", "message_delta"},
{"data", {
{"type", "message_delta"},
{"delta", {
{"stop_reason", stop_reason},
{"stop_sequence", stopping_word.empty() ? nullptr : json(stopping_word)}
}},
{"usage", {
{"output_tokens", n_decoded}
}}
}}
});
events.push_back({
{"event", "message_stop"},
{"data", {
{"type", "message_stop"}
}}
});
return events;
}
//
// server_task_result_cmpl_partial
//
json server_task_result_cmpl_partial::to_json() {
switch (oaicompat) {
case OAICOMPAT_TYPE_NONE:
switch (res_type) {
case TASK_RESPONSE_TYPE_NONE:
return to_json_non_oaicompat();
case OAICOMPAT_TYPE_COMPLETION:
case TASK_RESPONSE_TYPE_OAI_CMPL:
return to_json_oaicompat();
case OAICOMPAT_TYPE_CHAT:
case TASK_RESPONSE_TYPE_OAI_CHAT:
return to_json_oaicompat_chat();
case TASK_RESPONSE_TYPE_ANTHROPIC:
return to_json_anthropic();
default:
GGML_ASSERT(false && "Invalid oaicompat_type");
GGML_ASSERT(false && "Invalid task_response_type");
}
}
@ -910,7 +1096,7 @@ json server_task_result_cmpl_partial::to_json_oaicompat_chat() {
// server_task_result_embd
//
json server_task_result_embd::to_json() {
return oaicompat == OAICOMPAT_TYPE_EMBEDDING
return res_type == TASK_RESPONSE_TYPE_OAI_EMBD
? to_json_oaicompat()
: to_json_non_oaicompat();
}
@ -941,6 +1127,102 @@ json server_task_result_rerank::to_json() {
};
}
json server_task_result_cmpl_partial::to_json_anthropic() {
json events = json::array();
bool first = (n_decoded == 1);
static bool text_block_started = false;
if (first) {
text_block_started = false;
events.push_back({
{"event", "message_start"},
{"data", {
{"type", "message_start"},
{"message", {
{"id", oaicompat_cmpl_id},
{"type", "message"},
{"role", "assistant"},
{"content", json::array()},
{"model", oaicompat_model},
{"stop_reason", nullptr},
{"stop_sequence", nullptr},
{"usage", {
{"input_tokens", n_prompt_tokens},
{"output_tokens", 0}
}}
}}
}}
});
}
for (const auto & diff : oaicompat_msg_diffs) {
if (!diff.content_delta.empty()) {
if (!text_block_started) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", 0},
{"content_block", {
{"type", "text"},
{"text", ""}
}}
}}
});
text_block_started = true;
}
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", 0},
{"delta", {
{"type", "text_delta"},
{"text", diff.content_delta}
}}
}}
});
}
if (diff.tool_call_index != std::string::npos) {
size_t content_block_index = (text_block_started ? 1 : 0) + diff.tool_call_index;
if (!diff.tool_call_delta.name.empty()) {
events.push_back({
{"event", "content_block_start"},
{"data", {
{"type", "content_block_start"},
{"index", content_block_index},
{"content_block", {
{"type", "tool_use"},
{"id", diff.tool_call_delta.id},
{"name", diff.tool_call_delta.name}
}}
}}
});
}
if (!diff.tool_call_delta.arguments.empty()) {
events.push_back({
{"event", "content_block_delta"},
{"data", {
{"type", "content_block_delta"},
{"index", content_block_index},
{"delta", {
{"type", "input_json_delta"},
{"partial_json", diff.tool_call_delta.arguments}
}}
}}
});
}
}
}
return events;
}
//
// server_task_result_error
//

View File

@ -27,11 +27,12 @@ enum server_task_type {
};
// TODO: change this to more generic "response_format" to replace the "format_response_*" in server-common
enum oaicompat_type {
OAICOMPAT_TYPE_NONE,
OAICOMPAT_TYPE_CHAT,
OAICOMPAT_TYPE_COMPLETION,
OAICOMPAT_TYPE_EMBEDDING,
enum task_response_type {
TASK_RESPONSE_TYPE_NONE, // llama.cpp native format
TASK_RESPONSE_TYPE_OAI_CHAT,
TASK_RESPONSE_TYPE_OAI_CMPL,
TASK_RESPONSE_TYPE_OAI_EMBD,
TASK_RESPONSE_TYPE_ANTHROPIC,
};
enum stop_type {
@ -66,9 +67,9 @@ struct task_params {
struct common_params_sampling sampling;
struct common_params_speculative speculative;
// OAI-compat fields
// response formatting
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_syntax oaicompat_chat_syntax;
@ -227,12 +228,12 @@ struct server_task_result_cmpl_final : server_task_result {
task_params generation_params;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
common_chat_msg oaicompat_msg;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
@ -253,6 +254,10 @@ struct server_task_result_cmpl_final : server_task_result {
json to_json_oaicompat_chat();
json to_json_oaicompat_chat_stream();
json to_json_anthropic();
json to_json_anthropic_stream();
};
struct server_task_result_cmpl_partial : server_task_result {
@ -270,11 +275,11 @@ struct server_task_result_cmpl_partial : server_task_result {
result_timings timings;
result_prompt_progress progress;
// OAI-compat fields
bool verbose = false;
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
// response formatting
bool verbose = false;
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
std::string oaicompat_model;
std::string oaicompat_cmpl_id;
std::vector<common_chat_msg_diff> oaicompat_msg_diffs;
virtual int get_index() override {
@ -292,6 +297,8 @@ struct server_task_result_cmpl_partial : server_task_result {
json to_json_oaicompat();
json to_json_oaicompat_chat();
json to_json_anthropic();
};
struct server_task_result_embd : server_task_result {
@ -300,8 +307,8 @@ struct server_task_result_embd : server_task_result {
int32_t n_tokens;
// OAI-compat fields
oaicompat_type oaicompat = OAICOMPAT_TYPE_NONE;
// response formatting
task_response_type res_type = TASK_RESPONSE_TYPE_NONE;
virtual int get_index() override {
return index;

View File

@ -1252,7 +1252,7 @@ struct server_context {
res->post_sampling_probs = slot.task->params.post_sampling_probs;
res->verbose = slot.task->params.verbose;
res->oaicompat = slot.task->params.oaicompat;
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
@ -1294,7 +1294,7 @@ struct server_context {
res->verbose = slot.task->params.verbose;
res->stream = slot.task->params.stream;
res->include_usage = slot.task->params.include_usage;
res->oaicompat = slot.task->params.oaicompat;
res->res_type = slot.task->params.res_type;
res->oaicompat_model = slot.task->params.oaicompat_model;
res->oaicompat_cmpl_id = slot.task->params.oaicompat_cmpl_id;
res->oaicompat_msg = slot.update_chat_msg(res->oaicompat_msg_diffs);
@ -1325,7 +1325,7 @@ struct server_context {
res->id = slot.task->id;
res->index = slot.task->index;
res->n_tokens = slot.task->n_tokens();
res->oaicompat = slot.task->params.oaicompat;
res->res_type = slot.task->params.res_type;
const int n_embd = llama_model_n_embd(model);
@ -2710,7 +2710,8 @@ public:
res->headers["Process-Start-Time-Unix"] = std::to_string(res_task->t_start);
res->content_type = "text/plain; version=0.0.4";
res->ok(prometheus.str());
res->status = 200;
res->data = prometheus.str();
return res;
};
@ -2948,7 +2949,7 @@ public:
data,
files,
req.should_stop,
OAICOMPAT_TYPE_NONE); // infill is not OAI compatible
TASK_RESPONSE_TYPE_NONE); // infill is not OAI compatible
};
server_http_context::handler_t post_completions = [this](const server_http_req & req) {
@ -2959,7 +2960,7 @@ public:
body,
files,
req.should_stop,
OAICOMPAT_TYPE_NONE);
TASK_RESPONSE_TYPE_NONE);
};
server_http_context::handler_t post_completions_oai = [this](const server_http_req & req) {
@ -2970,7 +2971,7 @@ public:
body,
files,
req.should_stop,
OAICOMPAT_TYPE_COMPLETION);
TASK_RESPONSE_TYPE_OAI_CMPL);
};
server_http_context::handler_t post_chat_completions = [this](const server_http_req & req) {
@ -2985,7 +2986,38 @@ public:
body_parsed,
files,
req.should_stop,
OAICOMPAT_TYPE_CHAT);
TASK_RESPONSE_TYPE_OAI_CHAT);
};
server_http_context::handler_t post_anthropic_messages = [this](const server_http_req & req) {
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
files);
return handle_completions_impl(
SERVER_TASK_TYPE_COMPLETION,
body_parsed,
files,
req.should_stop,
TASK_RESPONSE_TYPE_ANTHROPIC);
};
server_http_context::handler_t post_anthropic_count_tokens = [this](const server_http_req & req) {
auto res = std::make_unique<server_res_generator>(ctx_server);
std::vector<raw_buffer> files;
json body = convert_anthropic_to_oai(json::parse(req.body));
json body_parsed = oaicompat_chat_params_parse(
body,
ctx_server.oai_parser_opt,
files);
json prompt = body_parsed.at("prompt");
llama_tokens tokens = tokenize_mixed(ctx_server.vocab, prompt, true, true);
res->ok({{"input_tokens", static_cast<int>(tokens.size())}});
return res;
};
// same with handle_chat_completions, but without inference part
@ -3104,11 +3136,11 @@ public:
};
server_http_context::handler_t post_embeddings = [this](const server_http_req & req) {
return handle_embeddings_impl(req, OAICOMPAT_TYPE_NONE);
return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_NONE);
};
server_http_context::handler_t post_embeddings_oai = [this](const server_http_req & req) {
return handle_embeddings_impl(req, OAICOMPAT_TYPE_EMBEDDING);
return handle_embeddings_impl(req, TASK_RESPONSE_TYPE_OAI_EMBD);
};
server_http_context::handler_t post_rerank = [this](const server_http_req & req) {
@ -3259,7 +3291,7 @@ private:
const json & data,
const std::vector<raw_buffer> & files,
const std::function<bool()> & should_stop,
oaicompat_type oaicompat) {
task_response_type res_type) {
GGML_ASSERT(type == SERVER_TASK_TYPE_COMPLETION || type == SERVER_TASK_TYPE_INFILL);
auto res = std::make_unique<server_res_generator>(ctx_server);
@ -3276,7 +3308,7 @@ private:
// process prompt
std::vector<server_tokens> inputs;
if (oaicompat && ctx_server.mctx != nullptr) {
if (res_type != TASK_RESPONSE_TYPE_NONE && ctx_server.mctx != nullptr) {
// This is the case used by OAI compatible chat path with MTMD. TODO It can be moved to the path below.
inputs.push_back(process_mtmd_prompt(ctx_server.mctx, prompt.get<std::string>(), files));
} else {
@ -3298,8 +3330,8 @@ private:
task.id_slot = json_value(data, "id_slot", -1);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.oaicompat_cmpl_id = completion_id;
task.params.res_type = res_type;
task.params.oaicompat_cmpl_id = completion_id;
// oaicompat_model is already populated by params_from_json_cmpl
tasks.push_back(std::move(task));
@ -3349,10 +3381,14 @@ private:
}
// next responses are streamed
res->data = format_sse(first_result->to_json()); // to be sent immediately
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
res->data = format_anthropic_sse(first_result->to_json());
} else {
res->data = format_oai_sse(first_result->to_json()); // to be sent immediately
}
res->status = 200;
res->content_type = "text/event-stream";
res->next = [res_this = res.get(), oaicompat, &should_stop](std::string & output) -> bool {
res->next = [res_this = res.get(), res_type, &should_stop](std::string & output) -> bool {
if (should_stop()) {
SRV_DBG("%s", "stopping streaming due to should_stop condition\n");
return false; // should_stop condition met
@ -3369,7 +3405,10 @@ private:
// check if there is more data
if (!rd.has_next()) {
if (oaicompat != OAICOMPAT_TYPE_NONE) {
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
// Anthropic doesn't send [DONE], message_stop was already sent
output = "";
} else if (res_type != TASK_RESPONSE_TYPE_NONE) {
output = "data: [DONE]\n\n";
} else {
output = "";
@ -3388,7 +3427,14 @@ private:
// send the results
json res_json = result->to_json();
if (result->is_error()) {
output = format_sse(json {{ "error", res_json }});
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse({
{"event", "error"},
{"data", res_json},
});
} else {
output = format_oai_sse(json {{ "error", res_json }});
}
SRV_DBG("%s", "error received during streaming, terminating stream\n");
return false; // terminate on error
} else {
@ -3396,7 +3442,11 @@ private:
dynamic_cast<server_task_result_cmpl_partial*>(result.get()) != nullptr
|| dynamic_cast<server_task_result_cmpl_final*>(result.get()) != nullptr
);
output = format_sse(res_json);
if (res_type == TASK_RESPONSE_TYPE_ANTHROPIC) {
output = format_anthropic_sse(res_json);
} else {
output = format_oai_sse(res_json);
}
}
// has next data, continue
@ -3504,14 +3554,14 @@ private:
return res;
}
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, oaicompat_type oaicompat) {
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type) {
auto res = std::make_unique<server_res_generator>(ctx_server);
if (!ctx_server.params_base.embedding) {
res->error(format_error_response("This server does not support embeddings. Start it with `--embeddings`", ERROR_TYPE_NOT_SUPPORTED));
return res;
}
if (oaicompat != OAICOMPAT_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
if (res_type != TASK_RESPONSE_TYPE_NONE && llama_pooling_type(ctx_server.ctx) == LLAMA_POOLING_TYPE_NONE) {
res->error(format_error_response("Pooling type 'none' is not OAI compatible. Please use a different pooling type", ERROR_TYPE_INVALID_REQUEST));
return res;
}
@ -3523,7 +3573,7 @@ private:
if (body.count("input") != 0) {
prompt = body.at("input");
} else if (body.contains("content")) {
oaicompat = OAICOMPAT_TYPE_NONE; // "content" field is not OAI compatible
res_type = TASK_RESPONSE_TYPE_NONE; // "content" field is not OAI compatible
prompt = body.at("content");
} else {
res->error(format_error_response("\"input\" or \"content\" must be provided", ERROR_TYPE_INVALID_REQUEST));
@ -3571,7 +3621,7 @@ private:
task.tokens = std::move(tokenized_prompts[i]);
// OAI-compat
task.params.oaicompat = oaicompat;
task.params.res_type = res_type;
task.params.embd_normalize = embd_normalize;
tasks.push_back(std::move(task));
@ -3596,7 +3646,7 @@ private:
}
// write JSON response
json root = oaicompat == OAICOMPAT_TYPE_EMBEDDING
json root = res_type == TASK_RESPONSE_TYPE_OAI_EMBD
? format_embeddings_response_oaicompat(body, responses, use_base64)
: json(responses);
res->ok(root);
@ -3709,6 +3759,8 @@ int main(int argc, char ** argv) {
ctx_http.post("/chat/completions", ex_wrapper(routes.post_chat_completions));
ctx_http.post("/v1/chat/completions", ex_wrapper(routes.post_chat_completions));
ctx_http.post("/api/chat", ex_wrapper(routes.post_chat_completions)); // ollama specific endpoint
ctx_http.post("/v1/messages", ex_wrapper(routes.post_anthropic_messages)); // anthropic messages API
ctx_http.post("/v1/messages/count_tokens", ex_wrapper(routes.post_anthropic_count_tokens)); // anthropic token counting
ctx_http.post("/infill", ex_wrapper(routes.post_infill));
ctx_http.post("/embedding", ex_wrapper(routes.post_embeddings)); // legacy
ctx_http.post("/embeddings", ex_wrapper(routes.post_embeddings));

View File

@ -13,3 +13,9 @@ def stop_server_after_each_test():
) # copy the set to prevent 'Set changed size during iteration'
for server in instances:
server.stop()
@pytest.fixture(scope="module", autouse=True)
def do_something():
# this will be run once per test module, before any of its tests
ServerPreset.load_all()

View File

@ -5,12 +5,6 @@ from utils import *
server = ServerPreset.tinyllama2()
@pytest.fixture(scope="session", autouse=True)
def do_something():
# this will be run once per test session, before any tests
ServerPreset.load_all()
@pytest.fixture(autouse=True)
def create_server():
global server

View File

@ -0,0 +1,807 @@
#!/usr/bin/env python3
import pytest
import base64
import requests
from utils import *
server: ServerProcess
def get_test_image_base64() -> str:
"""Get a test image in base64 format"""
# Use the same test image as test_vision_api.py
IMG_URL = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
response = requests.get(IMG_URL)
response.raise_for_status()
return base64.b64encode(response.content).decode("utf-8")
@pytest.fixture(autouse=True)
def create_server():
global server
server = ServerPreset.tinyllama2()
server.model_alias = "tinyllama-2-anthropic"
server.server_port = 8082
server.n_slots = 1
server.n_ctx = 8192
server.n_batch = 2048
@pytest.fixture
def vision_server():
"""Separate fixture for vision tests that require multimodal support"""
global server
server = ServerPreset.tinygemma3()
server.offline = False # Allow downloading the model
server.model_alias = "tinygemma3-anthropic"
server.server_port = 8083 # Different port to avoid conflicts
server.n_slots = 1
return server
# Basic message tests
def test_anthropic_messages_basic():
"""Test basic Anthropic messages endpoint"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Say hello"}
]
})
assert res.status_code == 200, f"Expected 200, got {res.status_code}"
assert res.body["type"] == "message", f"Expected type 'message', got {res.body.get('type')}"
assert res.body["role"] == "assistant", f"Expected role 'assistant', got {res.body.get('role')}"
assert "content" in res.body, "Missing 'content' field"
assert isinstance(res.body["content"], list), "Content should be an array"
assert len(res.body["content"]) > 0, "Content array should not be empty"
assert res.body["content"][0]["type"] == "text", "First content block should be text"
assert "text" in res.body["content"][0], "Text content block missing 'text' field"
assert res.body["stop_reason"] in ["end_turn", "max_tokens"], f"Invalid stop_reason: {res.body.get('stop_reason')}"
assert "usage" in res.body, "Missing 'usage' field"
assert "input_tokens" in res.body["usage"], "Missing usage.input_tokens"
assert "output_tokens" in res.body["usage"], "Missing usage.output_tokens"
assert isinstance(res.body["usage"]["input_tokens"], int), "input_tokens should be integer"
assert isinstance(res.body["usage"]["output_tokens"], int), "output_tokens should be integer"
assert res.body["usage"]["output_tokens"] > 0, "Should have generated some tokens"
# Anthropic API should NOT include timings
assert "timings" not in res.body, "Anthropic API should not include timings field"
def test_anthropic_messages_with_system():
"""Test messages with system prompt"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
def test_anthropic_messages_multipart_content():
"""Test messages with multipart content blocks"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "What is"},
{"type": "text", "text": " the answer?"}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_messages_conversation():
"""Test multi-turn conversation"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"},
{"role": "assistant", "content": "Hi there!"},
{"role": "user", "content": "How are you?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Streaming tests
def test_anthropic_messages_streaming():
"""Test streaming messages"""
server.start()
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 30,
"messages": [
{"role": "user", "content": "Say hello"}
],
"stream": True
})
events = []
for data in res:
# Each event should have type and other fields
assert "type" in data, f"Missing 'type' in event: {data}"
events.append(data)
# Verify event sequence
event_types = [e["type"] for e in events]
assert "message_start" in event_types, "Missing message_start event"
assert "content_block_start" in event_types, "Missing content_block_start event"
assert "content_block_delta" in event_types, "Missing content_block_delta event"
assert "content_block_stop" in event_types, "Missing content_block_stop event"
assert "message_delta" in event_types, "Missing message_delta event"
assert "message_stop" in event_types, "Missing message_stop event"
# Check message_start structure
message_start = next(e for e in events if e["type"] == "message_start")
assert "message" in message_start, "message_start missing 'message' field"
assert message_start["message"]["type"] == "message"
assert message_start["message"]["role"] == "assistant"
assert message_start["message"]["content"] == []
assert "usage" in message_start["message"]
assert message_start["message"]["usage"]["input_tokens"] > 0
# Check content_block_start
block_start = next(e for e in events if e["type"] == "content_block_start")
assert "index" in block_start, "content_block_start missing 'index'"
assert block_start["index"] == 0, "First content block should be at index 0"
assert "content_block" in block_start
assert block_start["content_block"]["type"] == "text"
# Check content_block_delta
deltas = [e for e in events if e["type"] == "content_block_delta"]
assert len(deltas) > 0, "Should have at least one content_block_delta"
for delta in deltas:
assert "index" in delta
assert "delta" in delta
assert delta["delta"]["type"] == "text_delta"
assert "text" in delta["delta"]
# Check content_block_stop
block_stop = next(e for e in events if e["type"] == "content_block_stop")
assert "index" in block_stop
assert block_stop["index"] == 0
# Check message_delta
message_delta = next(e for e in events if e["type"] == "message_delta")
assert "delta" in message_delta
assert "stop_reason" in message_delta["delta"]
assert message_delta["delta"]["stop_reason"] in ["end_turn", "max_tokens"]
assert "usage" in message_delta
assert message_delta["usage"]["output_tokens"] > 0
# Check message_stop
message_stop = next(e for e in events if e["type"] == "message_stop")
# message_stop should NOT have timings for Anthropic API
assert "timings" not in message_stop, "Anthropic streaming should not include timings"
# Token counting tests
def test_anthropic_count_tokens():
"""Test token counting endpoint"""
server.start()
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"messages": [
{"role": "user", "content": "Hello world"}
]
})
assert res.status_code == 200
assert "input_tokens" in res.body
assert isinstance(res.body["input_tokens"], int)
assert res.body["input_tokens"] > 0
# Should only have input_tokens, no other fields
assert "output_tokens" not in res.body
def test_anthropic_count_tokens_with_system():
"""Test token counting with system prompt"""
server.start()
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"system": "You are a helpful assistant.",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["input_tokens"] > 0
def test_anthropic_count_tokens_no_max_tokens():
"""Test that count_tokens doesn't require max_tokens"""
server.start()
# max_tokens is NOT required for count_tokens
res = server.make_request("POST", "/v1/messages/count_tokens", data={
"model": "test",
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert "input_tokens" in res.body
# Tool use tests
def test_anthropic_tool_use_basic():
"""Test basic tool use"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"tools": [{
"name": "get_weather",
"description": "Get the current weather in a location",
"input_schema": {
"type": "object",
"properties": {
"location": {
"type": "string",
"description": "City name"
}
},
"required": ["location"]
}
}],
"messages": [
{"role": "user", "content": "What's the weather in Paris?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
# Check if model used the tool (it might not always, depending on the model)
content_types = [block.get("type") for block in res.body["content"]]
if "tool_use" in content_types:
# Model used the tool
assert res.body["stop_reason"] == "tool_use"
# Find the tool_use block
tool_block = next(b for b in res.body["content"] if b.get("type") == "tool_use")
assert "id" in tool_block
assert "name" in tool_block
assert tool_block["name"] == "get_weather"
assert "input" in tool_block
assert isinstance(tool_block["input"], dict)
def test_anthropic_tool_result():
"""Test sending tool results back
This test verifies that tool_result blocks are properly converted to
role="tool" messages internally. Without proper conversion, this would
fail with a 500 error: "unsupported content[].type" because tool_result
blocks would remain in the user message content array.
"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "test123",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "test123",
"content": "The weather is sunny, 25°C"
}
]
}
]
})
# This would be 500 with the old bug where tool_result blocks weren't converted
assert res.status_code == 200
assert res.body["type"] == "message"
# Model should respond to the tool result
assert len(res.body["content"]) > 0
assert res.body["content"][0]["type"] == "text"
def test_anthropic_tool_result_with_text():
"""Test tool result mixed with text content
This tests the edge case where a user message contains both text and
tool_result blocks. The server must properly split these into separate
messages: a user message with text, followed by tool messages.
Without proper handling, this would fail with 500: "unsupported content[].type"
"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "What's the weather?"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "tool_1",
"name": "get_weather",
"input": {"location": "Paris"}
}
]
},
{
"role": "user",
"content": [
{"type": "text", "text": "Here are the results:"},
{
"type": "tool_result",
"tool_use_id": "tool_1",
"content": "Sunny, 25°C"
}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
def test_anthropic_tool_result_error():
"""Test tool result with error flag"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"messages": [
{"role": "user", "content": "Get the weather"},
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "test123",
"name": "get_weather",
"input": {"location": "InvalidCity"}
}
]
},
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "test123",
"is_error": True,
"content": "City not found"
}
]
}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_tool_streaming():
"""Test streaming with tool use"""
server.jinja = True
server.start()
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"stream": True,
"tools": [{
"name": "calculator",
"description": "Calculate math",
"input_schema": {
"type": "object",
"properties": {
"expression": {"type": "string"}
},
"required": ["expression"]
}
}],
"messages": [
{"role": "user", "content": "Calculate 2+2"}
]
})
events = []
for data in res:
events.append(data)
event_types = [e["type"] for e in events]
# Should have basic events
assert "message_start" in event_types
assert "message_stop" in event_types
# If tool was used, check for proper tool streaming
if any(e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "tool_use"
for e in events):
# Find tool use block start
tool_starts = [e for e in events if
e.get("type") == "content_block_start" and
e.get("content_block", {}).get("type") == "tool_use"]
assert len(tool_starts) > 0, "Should have tool_use content_block_start"
# Check index is correct (should be 0 if no text, 1 if there's text)
tool_start = tool_starts[0]
assert "index" in tool_start
assert tool_start["content_block"]["type"] == "tool_use"
assert "name" in tool_start["content_block"]
# Vision/multimodal tests
def test_anthropic_vision_format_accepted():
"""Test that Anthropic vision format is accepted (format validation only)"""
server.start()
# Small 1x1 red PNG image in base64
red_pixel_png = "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8DwHwAFBQIAX8jx0gAAAABJRU5ErkJggg=="
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 10,
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": red_pixel_png
}
},
{
"type": "text",
"text": "What is this?"
}
]
}
]
})
# Server accepts the format but tinyllama doesn't support images
# So it should return 500 with clear error message about missing mmproj
assert res.status_code == 500
assert "image input is not supported" in res.body.get("error", {}).get("message", "").lower()
def test_anthropic_vision_base64_with_multimodal_model(vision_server):
"""Test vision with base64 image using Anthropic format with multimodal model"""
global server
server = vision_server
server.start()
# Get test image in base64 format
image_base64 = get_test_image_base64()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 10,
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": image_base64
}
},
{
"type": "text",
"text": "What is this:\n"
}
]
}
]
})
assert res.status_code == 200, f"Expected 200, got {res.status_code}: {res.body}"
assert res.body["type"] == "message"
assert len(res.body["content"]) > 0
assert res.body["content"][0]["type"] == "text"
# The model should generate some response about the image
assert len(res.body["content"][0]["text"]) > 0
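# For reference, an Anthropic-style base64 image block like the ones used above can
# be built from raw bytes with the standard library. This is a hypothetical helper
# for illustration; the tests use pre-encoded fixtures instead.
def make_image_block(data: bytes, media_type: str = "image/png") -> dict:
    """Wrap raw image bytes in an Anthropic base64 image content block."""
    import base64
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": base64.b64encode(data).decode("ascii"),
        },
    }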
# Parameter tests
def test_anthropic_stop_sequences():
"""Test stop_sequences parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"stop_sequences": ["\n", "END"],
"messages": [
{"role": "user", "content": "Count to 10"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_temperature():
"""Test temperature parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"temperature": 0.5,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_top_p():
"""Test top_p parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"top_p": 0.9,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_top_k():
"""Test top_k parameter (llama.cpp specific)"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"top_k": 40,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Error handling tests
def test_anthropic_missing_messages():
"""Test error when messages are missing"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50
# missing "messages" field
})
# Should return an error (400 or 500)
assert res.status_code >= 400
def test_anthropic_empty_messages():
"""Test permissive handling of empty messages array"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": []
})
# The server accepts an empty messages array and falls back to defaults,
# in line with its deliberately lenient validation design
assert res.status_code == 200
assert res.body["type"] == "message"
# Content block index tests
def test_anthropic_streaming_content_block_indices():
"""Test that content block indices are correct in streaming"""
server.jinja = True
server.start()
# Request that might produce both text and tool use
res = server.make_stream_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 200,
"stream": True,
"tools": [{
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"param": {"type": "string"}
},
"required": ["param"]
}
}],
"messages": [
{"role": "user", "content": "Use the test tool"}
]
})
events = []
for data in res:
events.append(data)
# Check content_block_start events have sequential indices
block_starts = [e for e in events if e.get("type") == "content_block_start"]
if len(block_starts) > 1:
# If there are multiple blocks, indices should be sequential
indices = [e["index"] for e in block_starts]
expected_indices = list(range(len(block_starts)))
assert indices == expected_indices, f"Expected indices {expected_indices}, got {indices}"
# Check content_block_stop events match the starts
block_stops = [e for e in events if e.get("type") == "content_block_stop"]
start_indices = set(e["index"] for e in block_starts)
stop_indices = set(e["index"] for e in block_stops)
assert start_indices == stop_indices, "content_block_stop indices should match content_block_start indices"
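# The index bookkeeping above amounts to a simple invariant: start indices count up
# from 0 and every start has a matching stop. A sketch of that check (not a helper
# the suite defines) could look like this.
def block_indices_consistent(events: list) -> bool:
    """True if content_block_start/stop indices pair up and count from zero."""
    starts = [e["index"] for e in events if e.get("type") == "content_block_start"]
    stops = {e["index"] for e in events if e.get("type") == "content_block_stop"}
    return starts == list(range(len(starts))) and set(starts) == stops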
# Extended features tests
def test_anthropic_thinking():
"""Test extended thinking parameter"""
server.jinja = True
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 100,
"thinking": {
"type": "enabled",
"budget_tokens": 50
},
"messages": [
{"role": "user", "content": "What is 2+2?"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
def test_anthropic_metadata():
"""Test metadata parameter"""
server.start()
res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"metadata": {
"user_id": "test_user_123"
},
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert res.status_code == 200
assert res.body["type"] == "message"
# Compatibility tests
def test_anthropic_vs_openai_different_response_format():
"""Verify Anthropic format is different from OpenAI format"""
server.start()
# Make OpenAI request
openai_res = server.make_request("POST", "/v1/chat/completions", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"}
]
})
# Make Anthropic request
anthropic_res = server.make_request("POST", "/v1/messages", data={
"model": "test",
"max_tokens": 50,
"messages": [
{"role": "user", "content": "Hello"}
]
})
assert openai_res.status_code == 200
assert anthropic_res.status_code == 200
# OpenAI has "object", Anthropic has "type"
assert "object" in openai_res.body
assert "type" in anthropic_res.body
assert openai_res.body["object"] == "chat.completion"
assert anthropic_res.body["type"] == "message"
# OpenAI has "choices", Anthropic has "content"
assert "choices" in openai_res.body
assert "content" in anthropic_res.body
# Different usage field names
assert "prompt_tokens" in openai_res.body["usage"]
assert "input_tokens" in anthropic_res.body["usage"]
assert "completion_tokens" in openai_res.body["usage"]
assert "output_tokens" in anthropic_res.body["usage"]

View File

@ -49,6 +49,19 @@ def test_correct_api_key():
assert "content" in res.body
def test_correct_api_key_anthropic_header():
global server
server.start()
res = server.make_request("POST", "/completions", data={
"prompt": "I believe the meaning of life is",
}, headers={
"X-Api-Key": TEST_API_KEY,
})
assert res.status_code == 200
assert "error" not in res.body
assert "content" in res.body
def test_openai_library_correct_api_key():
global server
server.start()

View File

@ -205,6 +205,8 @@ class ServerProcess:
server_args.append("--no-webui")
if self.jinja:
server_args.append("--jinja")
else:
server_args.append("--no-jinja")
if self.reasoning_format is not None:
server_args.extend(("--reasoning-format", self.reasoning_format))
if self.reasoning_budget is not None: