Merge branch 'master' into xsn/server_model_management_v1_2

This commit is contained in:
Xuan Son Nguyen 2025-11-29 22:34:16 +01:00
commit c1dfccd078
50 changed files with 5652 additions and 5133 deletions

View File

@ -45,7 +45,7 @@ sd=`dirname $0`
cd $sd/../ cd $sd/../
SRC=`pwd` SRC=`pwd`
CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON" CMAKE_EXTRA="-DLLAMA_FATAL_WARNINGS=ON -DLLAMA_CURL=ON -DGGML_SCHED_NO_REALLOC=ON"
if [ ! -z ${GG_BUILD_METAL} ]; then if [ ! -z ${GG_BUILD_METAL} ]; then
CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON" CMAKE_EXTRA="${CMAKE_EXTRA} -DGGML_METAL=ON"
@ -428,8 +428,8 @@ function gg_run_qwen3_0_6b {
(time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log (time ./bin/llama-imatrix --model ${model_f16} -f ${wiki_test} -ngl 99 -c 1024 -b 512 --chunks 2 ) 2>&1 | tee -a $OUT/${ci}-imatrix.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa off --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 10 -c 1024 -fa on --no-op-offload) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa off ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
(time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log (time ./bin/llama-save-load-state --model ${model_q4_0} -ngl 99 -c 1024 -fa on ) 2>&1 | tee -a $OUT/${ci}-save-load-state.log
@ -523,8 +523,8 @@ function gg_run_embd_bge_small {
./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0 ./bin/llama-quantize ${model_f16} ${model_q8_0} q8_0
(time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-f16.log (time ./bin/llama-embedding --model ${model_f16} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-f16.log
(time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 ) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log (time ./bin/llama-embedding --model ${model_q8_0} -p "I believe the meaning of life is" -ngl 99 -c 0 --no-op-offload) 2>&1 | tee -a $OUT/${ci}-tg-q8_0.log
set +e set +e
} }
@ -564,7 +564,7 @@ function gg_run_rerank_tiny {
model_f16="${path_models}/ggml-model-f16.gguf" model_f16="${path_models}/ggml-model-f16.gguf"
# for this model, the SEP token is "</s>" # for this model, the SEP token is "</s>"
(time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log (time ./bin/llama-embedding --model ${model_f16} -p "what is panda?\thi\nwhat is panda?\tit's a bear\nwhat is panda?\tThe giant panda (Ailuropoda melanoleuca), sometimes called a panda bear or simply panda, is a bear species endemic to China." -ngl 99 -c 0 --pooling rank --embd-normalize -1 --no-op-offload --verbose-prompt) 2>&1 | tee -a $OUT/${ci}-rk-f16.log
# sample output # sample output
# rerank score 0: 0.029 # rerank score 0: 0.029

View File

@ -13,6 +13,120 @@
using json = nlohmann::ordered_json; using json = nlohmann::ordered_json;
static void parse_prefixed_json_tool_call_array(common_chat_msg_parser & builder,
const common_regex & prefix,
size_t rstrip_prefix = 0) {
static const std::vector<std::vector<std::string>> args_paths = { { "arguments" } };
if (auto res = builder.try_find_regex(prefix)) {
builder.move_back(rstrip_prefix);
auto tool_calls = builder.consume_json_with_dumped_args(args_paths);
if (!builder.add_tool_calls(tool_calls.value) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call array");
}
} else {
builder.add_content(builder.consume_rest());
}
}
static std::string wrap_code_as_arguments(common_chat_msg_parser & builder, const std::string & code) {
std::string arguments;
if (builder.is_partial()) {
arguments = (json{
{ "code", code + builder.healing_marker() }
})
.dump();
auto idx = arguments.find(builder.healing_marker());
if (idx != std::string::npos) {
arguments.resize(idx);
}
} else {
arguments = (json{
{ "code", code }
})
.dump();
}
return arguments;
}
/**
* Takes a prefix regex that must have 1 group to capture the function name, a closing suffix, and expects json parameters in between.
* Aggregates the prefix, suffix and in-between text into the content.
*/
static void parse_json_tool_calls(
common_chat_msg_parser & builder,
const std::optional<common_regex> & block_open,
const std::optional<common_regex> & function_regex_start_only,
const std::optional<common_regex> & function_regex,
const common_regex & close_regex,
const std::optional<common_regex> & block_close,
bool allow_raw_python = false,
const std::function<std::string(const common_chat_msg_parser::find_regex_result & fres)> & get_function_name =
nullptr) {
auto parse_tool_calls = [&]() {
size_t from = std::string::npos;
auto first = true;
while (true) {
auto start_pos = builder.pos();
auto res = function_regex_start_only && first ? builder.try_consume_regex(*function_regex_start_only) :
function_regex ? builder.try_find_regex(*function_regex, from) :
std::nullopt;
if (res) {
std::string name;
if (get_function_name) {
name = get_function_name(*res);
} else {
GGML_ASSERT(res->groups.size() == 2);
name = builder.str(res->groups[1]);
}
first = false;
if (name.empty()) {
// get_function_name signalled us that we should skip this match and treat it as content.
from = res->groups[0].begin + 1;
continue;
}
from = std::string::npos;
auto maybe_raw_python = name == "python" && allow_raw_python;
if (builder.input()[builder.pos()] == '{' || !maybe_raw_python) {
if (auto arguments = builder.try_consume_json_with_dumped_args({ {} })) {
if (!builder.add_tool_call(name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(close_regex);
}
continue;
}
if (maybe_raw_python) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
if (!builder.add_tool_call(name, "", arguments)) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
return;
}
throw common_chat_msg_partial_exception("incomplete tool call");
} else {
builder.move_to(start_pos);
}
break;
}
if (block_close) {
builder.consume_regex(*block_close);
}
builder.consume_spaces();
builder.add_content(builder.consume_rest());
};
if (block_open) {
if (auto res = builder.try_find_regex(*block_open)) {
parse_tool_calls();
} else {
builder.add_content(builder.consume_rest());
}
} else {
parse_tool_calls();
}
}
common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax) common_chat_msg_parser::common_chat_msg_parser(const std::string & input, bool is_partial, const common_chat_syntax & syntax)
: input_(input), is_partial_(is_partial), syntax_(syntax) : input_(input), is_partial_(is_partial), syntax_(syntax)
{ {
@ -532,3 +646,857 @@ std::optional<common_chat_msg_parser::consume_json_result> common_chat_msg_parse
void common_chat_msg_parser::clear_tools() { void common_chat_msg_parser::clear_tools() {
result_.tool_calls.clear(); result_.tool_calls.clear();
} }
/**
* All common_chat_parse_* moved from chat.cpp to chat-parser.cpp below
* to reduce incremental compile time for parser changes.
*/
static void common_chat_parse_generic(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const std::vector<std::vector<std::string>> content_paths = {
{"response"},
};
static const std::vector<std::vector<std::string>> args_paths = {
{"tool_call", "arguments"},
{"tool_calls", "arguments"},
};
auto data = builder.consume_json_with_dumped_args(args_paths, content_paths);
if (data.value.contains("tool_calls")) {
if (!builder.add_tool_calls(data.value.at("tool_calls")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool calls");
}
} else if (data.value.contains("tool_call")) {
if (!builder.add_tool_call(data.value.at("tool_call")) || data.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (data.value.contains("response")) {
const auto & response = data.value.at("response");
builder.add_content(response.is_string() ? response.template get<std::string>() : response.dump(2));
if (data.is_partial) {
throw common_chat_msg_partial_exception("incomplete response");
}
} else {
throw common_chat_msg_partial_exception("Expected 'tool_call', 'tool_calls' or 'response' in JSON");
}
}
static void common_chat_parse_mistral_nemo(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_magistral(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("[THINK]", "[/THINK]");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape("[TOOL_CALLS]"));
parse_prefixed_json_tool_call_array(builder, prefix);
}
static void common_chat_parse_command_r7b(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<|START_THINKING|>", "<|END_THINKING|>");
static const common_regex start_action_regex("<\\|START_ACTION\\|>");
static const common_regex end_action_regex("<\\|END_ACTION\\|>");
static const common_regex start_response_regex("<\\|START_RESPONSE\\|>");
static const common_regex end_response_regex("<\\|END_RESPONSE\\|>");
if (auto res = builder.try_find_regex(start_action_regex)) {
// If we didn't extract thoughts, prelude includes them.
auto tool_calls = builder.consume_json_with_dumped_args({{"parameters"}});
for (const auto & tool_call : tool_calls.value) {
std::string name = tool_call.contains("tool_name") ? tool_call.at("tool_name") : "";
std::string id = tool_call.contains("tool_call_id") ? tool_call.at("tool_call_id") : "";
std::string arguments = tool_call.contains("parameters") ? tool_call.at("parameters") : "";
if (!builder.add_tool_call(name, id, arguments) || tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
if (tool_calls.is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_regex(end_action_regex);
} else if (auto res = builder.try_find_regex(start_response_regex)) {
if (!builder.try_find_regex(end_response_regex)) {
builder.add_content(builder.consume_rest());
throw common_chat_msg_partial_exception(end_response_regex.str());
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_llama_3_1(common_chat_msg_parser & builder, bool with_builtin_tools = false) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex function_regex(
"\\s*\\{\\s*(?:\"type\"\\s*:\\s*\"function\"\\s*,\\s*)?\"name\"\\s*:\\s*\"([^\"]+)\"\\s*,\\s*\"parameters\"\\s*: ");
static const common_regex close_regex("\\}\\s*");
static const common_regex function_name_regex("\\s*(\\w+)\\s*\\.\\s*call\\(");
static const common_regex arg_name_regex("\\s*(\\w+)\\s*=\\s*");
if (with_builtin_tools) {
static const common_regex builtin_call_regex("<\\|python_tag\\|>");
if (auto res = builder.try_find_regex(builtin_call_regex)) {
auto fun_res = builder.consume_regex(function_name_regex);
auto function_name = builder.str(fun_res.groups[1]);
common_healing_marker healing_marker;
json args = json::object();
while (true) {
if (auto arg_res = builder.try_consume_regex(arg_name_regex)) {
auto arg_name = builder.str(arg_res->groups[1]);
auto partial = builder.consume_json();
args[arg_name] = partial.json;
healing_marker.marker = partial.healing_marker.marker;
healing_marker.json_dump_marker = partial.healing_marker.json_dump_marker;
builder.consume_spaces();
if (!builder.try_consume_literal(",")) {
break;
}
} else {
break;
}
}
builder.consume_literal(")");
builder.consume_spaces();
auto arguments = args.dump();
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
return;
}
}
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ function_regex,
/* function_regex= */ std::nullopt,
close_regex,
std::nullopt);
}
static void common_chat_parse_deepseek_r1(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
static const common_regex function_regex("(?:<tool▁call▁begin>)?function<toolsep>([^\n]+)\n```json\n");
static const common_regex close_regex("```[\\s\\r\\n]*<tool▁call▁end>");
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1_content(common_chat_msg_parser & builder) {
static const common_regex function_regex("(?:<tool▁call▁begin>)?([^\\n<]+)(?:<toolsep>)");
static const common_regex close_regex("(?:[\\s]*)?<toolcallend>");
static const common_regex tool_calls_begin("(?:<tool▁calls▁begin>|<tool_calls_begin>|<tool calls begin>|<tool\\\\_calls\\\\_begin>|<tool▁calls>)");
static const common_regex tool_calls_end("<tool▁calls▁end>");
if (!builder.syntax().parse_tool_calls) {
LOG_DBG("%s: not parse_tool_calls\n", __func__);
builder.add_content(builder.consume_rest());
return;
}
LOG_DBG("%s: parse_tool_calls\n", __func__);
parse_json_tool_calls(
builder,
/* block_open= */ tool_calls_begin,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
tool_calls_end);
}
static void common_chat_parse_deepseek_v3_1(common_chat_msg_parser & builder) {
// DeepSeek V3.1 outputs reasoning content between "<think>" and "</think>" tags, followed by regular content
// First try to parse using the standard reasoning parsing method
LOG_DBG("%s: thinking_forced_open: %s\n", __func__, std::to_string(builder.syntax().thinking_forced_open).c_str());
auto start_pos = builder.pos();
auto found_end_think = builder.try_find_literal("</think>");
builder.move_to(start_pos);
if (builder.syntax().thinking_forced_open && !builder.is_partial() && !found_end_think) {
LOG_DBG("%s: no end_think, not partial, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
} else if (builder.try_parse_reasoning("<think>", "</think>")) {
// If reasoning was parsed successfully, the remaining content is regular content
LOG_DBG("%s: parsed reasoning, adding content\n", __func__);
// </think><tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>NAME\n```json\nJSON\n```<tool▁call▁end><tool▁calls▁end>
common_chat_parse_deepseek_v3_1_content(builder);
} else {
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE) {
LOG_DBG("%s: reasoning_format none, adding content\n", __func__);
common_chat_parse_deepseek_v3_1_content(builder);
return;
}
// If no reasoning tags found, check if we should treat everything as reasoning
if (builder.syntax().thinking_forced_open) {
// If thinking is forced open but no tags found, treat everything as reasoning
LOG_DBG("%s: thinking_forced_open, adding reasoning content\n", __func__);
builder.add_reasoning_content(builder.consume_rest());
} else {
LOG_DBG("%s: no thinking_forced_open, adding content\n", __func__);
// <tool▁call▁begin>NAME<tool▁sep>JSON<tool▁call▁end>
common_chat_parse_deepseek_v3_1_content(builder);
}
}
}
static void common_chat_parse_minimax_m2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<minimax:tool_call>",
/* form.tool_start = */ "<invoke name=\"",
/* form.tool_sep = */ "\">",
/* form.key_start = */ "<parameter name=\"",
/* form.key_val_sep = */ "\">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</invoke>",
/* form.scope_end = */ "</minimax:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_qwen3_coder_xml(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_call>";
form.tool_start = "<function=";
form.tool_sep = ">";
form.key_start = "<parameter=";
form.key_val_sep = ">";
form.val_end = "</parameter>";
form.tool_end = "</function>";
form.scope_end = "</tool_call>";
form.trim_raw_argval = true;
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_kimi_k2(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<|tool_calls_section_begin|>";
form.tool_start = "<|tool_call_begin|>";
form.tool_sep = "<|tool_call_argument_begin|>{";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}<|tool_call_end|>";
form.scope_end = "<|tool_calls_section_end|>";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_apriel_1_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "<tool_calls>[";
form.tool_start = "{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}, ";
form.scope_end = "]</tool_calls>";
form.raw_argval = false;
form.last_val_end = "";
form.last_tool_end = "}";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form, "<thinking>", "</thinking>");
}
static void common_chat_parse_xiaomi_mimo(common_chat_msg_parser & builder) {
static const xml_tool_call_format form = ([]() {
xml_tool_call_format form {};
form.scope_start = "";
form.tool_start = "<tool_call>\n{\"name\": \"";
form.tool_sep = "\", \"arguments\": {";
form.key_start = "\"";
form.key_val_sep = "\": ";
form.val_end = ", ";
form.tool_end = "}\n</tool_call>";
form.scope_end = "";
form.raw_argval = false;
form.last_val_end = "";
return form;
})();
builder.consume_reasoning_with_xml_tool_calls(form);
}
static void common_chat_parse_gpt_oss(common_chat_msg_parser & builder) {
static const std::string constraint = "(?: (<\\|constrain\\|>)?([a-zA-Z0-9_-]+))";
static const std::string recipient("(?: to=functions\\.([^<\\s]+))");
static const common_regex start_regex("<\\|start\\|>assistant");
static const common_regex analysis_regex("<\\|channel\\|>analysis");
static const common_regex final_regex("<\\|channel\\|>final" + constraint + "?");
static const common_regex preamble_regex("<\\|channel\\|>commentary");
static const common_regex tool_call1_regex(recipient + "<\\|channel\\|>(analysis|commentary)" + constraint + "?");
static const common_regex tool_call2_regex("<\\|channel\\|>(analysis|commentary)" + recipient + constraint + "?");
auto consume_end = [&](bool include_end = false) {
if (auto res = builder.try_find_literal("<|end|>")) {
return res->prelude + (include_end ? builder.str(res->groups[0]) : "");
}
return builder.consume_rest();
};
auto handle_tool_call = [&](const std::string & name) {
if (auto args = builder.try_consume_json_with_dumped_args({{}})) {
if (builder.syntax().parse_tool_calls) {
if (!builder.add_tool_call(name, "", args->value) || args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
} else if (args->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
};
auto regex_match = [](const common_regex & regex, const std::string & input) -> std::optional<common_regex_match> {
auto match = regex.search(input, 0, true);
if (match.type == COMMON_REGEX_MATCH_TYPE_FULL) {
return match;
}
return std::nullopt;
};
do {
auto header_start_pos = builder.pos();
auto content_start = builder.try_find_literal("<|message|>");
if (!content_start) {
throw common_chat_msg_partial_exception("incomplete header");
}
auto header = content_start->prelude;
if (auto match = regex_match(tool_call1_regex, header)) {
auto group = match->groups[1];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (auto match = regex_match(tool_call2_regex, header)) {
auto group = match->groups[2];
auto name = header.substr(group.begin, group.end - group.begin);
handle_tool_call(name);
continue;
}
if (regex_match(analysis_regex, header)) {
builder.move_to(header_start_pos);
if (builder.syntax().reasoning_format == COMMON_REASONING_FORMAT_NONE || builder.syntax().reasoning_in_content) {
builder.add_content(consume_end(true));
} else {
builder.try_parse_reasoning("<|channel|>analysis<|message|>", "<|end|>");
}
continue;
}
if(regex_match(final_regex, header) || regex_match(preamble_regex, header)) {
builder.add_content(consume_end());
continue;
}
// Possibly a malformed message, attempt to recover by rolling
// back to pick up the next <|start|>
LOG_DBG("%s: unknown header from message: %s\n", __func__, header.c_str());
builder.move_to(header_start_pos);
} while (builder.try_find_regex(start_regex, std::string::npos, false));
auto remaining = builder.consume_rest();
if (!remaining.empty()) {
LOG_DBG("%s: content after last message: %s\n", __func__, remaining.c_str());
}
}
static void common_chat_parse_glm_4_5(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "",
/* form.tool_start = */ "<tool_call>",
/* form.tool_sep = */ "",
/* form.key_start = */ "<arg_key>",
/* form.key_val_sep = */ "</arg_key>",
/* form.val_end = */ "</arg_value>",
/* form.tool_end = */ "</tool_call>",
/* form.scope_end = */ "",
/* form.key_val_sep2 = */ "<arg_value>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<think>", "</think>");
}
static void common_chat_parse_firefunction_v2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex prefix(regex_escape(" functools["));
parse_prefixed_json_tool_call_array(builder, prefix, /* rstrip_prefix= */ 1);
}
static void common_chat_parse_functionary_v3_2(common_chat_msg_parser & builder) {
static const common_regex function_regex_start_only(R"((\w+\n\{|python\n|all\n))");
static const common_regex function_regex(R"(>>>(\w+\n\{|python\n|all\n))");
static const common_regex close_regex(R"(\s*)");
parse_json_tool_calls(
builder,
std::nullopt,
function_regex_start_only,
function_regex,
close_regex,
std::nullopt,
/* allow_raw_python= */ true,
/* get_function_name= */ [&](const auto & res) -> std::string {
auto at_start = res.groups[0].begin == 0;
auto name = builder.str(res.groups[1]);
if (!name.empty() && name.back() == '{') {
// Unconsume the opening brace '{' to ensure the JSON parsing goes well.
builder.move_back(1);
}
auto idx = name.find_last_not_of("\n{");
name = name.substr(0, idx + 1);
if (at_start && name == "all") {
return "";
}
return name;
});
}
static void common_chat_parse_functionary_v3_1_llama_3_1(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// This version of Functionary still supports the llama 3.1 tool call format for the python tool.
static const common_regex python_tag_regex(regex_escape("<|python_tag|>"));
static const common_regex function_regex(R"(<function=(\w+)>)");
static const common_regex close_regex(R"(</function>)");
parse_json_tool_calls(
builder,
/* block_open= */ std::nullopt,
/* function_regex_start_only= */ std::nullopt,
function_regex,
close_regex,
std::nullopt);
if (auto res = builder.try_find_regex(python_tag_regex)) {
auto arguments = wrap_code_as_arguments(builder, builder.consume_rest());
builder.add_tool_call("python", "", arguments);
return;
}
}
static void common_chat_parse_hermes_2_pro(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
static const common_regex open_regex(
"(?:"
"(```(?:xml|json)?\\n\\s*)?" // match 1 (block_start)
"(" // match 2 (open_tag)
"<tool_call>"
"|<function_call>"
"|<tool>"
"|<tools>"
"|<response>"
"|<json>"
"|<xml>"
"|<JSON>"
")?"
"(\\s*\\{\\s*\"name\")" // match 3 (named tool call)
")"
"|<function=([^>]+)>" // match 4 (function name)
"|<function name=\"([^\"]+)\">" // match 5 (function name again)
);
while (auto res = builder.try_find_regex(open_regex)) {
const auto & block_start = res->groups[1];
std::string block_end = block_start.empty() ? "" : "```";
const auto & open_tag = res->groups[2];
std::string close_tag;
if (!res->groups[3].empty()) {
builder.move_to(res->groups[3].begin);
close_tag = open_tag.empty() ? "" : "</" + builder.str(open_tag).substr(1);
if (auto tool_call = builder.try_consume_json_with_dumped_args({{"arguments"}})) {
if (!builder.add_tool_call(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
} else {
throw common_chat_msg_partial_exception("failed to parse tool call");
}
} else {
auto function_name = builder.str(res->groups[4]);
if (function_name.empty()) {
function_name = builder.str(res->groups[5]);
}
GGML_ASSERT(!function_name.empty());
close_tag = "</function>";
if (auto arguments = builder.try_consume_json_with_dumped_args({{}})) {
if (!builder.add_tool_call(function_name, "", arguments->value) || arguments->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
builder.consume_spaces();
builder.consume_literal(close_tag);
builder.consume_spaces();
if (!block_end.empty()) {
builder.consume_literal(block_end);
builder.consume_spaces();
}
}
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_granite(common_chat_msg_parser & builder) {
// Parse thinking tags
static const common_regex start_think_regex(regex_escape("<think>"));
static const common_regex end_think_regex(regex_escape("</think>"));
// Granite models output partial tokens such as "<" and "<think".
// By leveraging try_consume_regex()/try_find_regex() throwing
// common_chat_msg_partial_exception for these partial tokens,
// processing is interrupted and the tokens are not passed to add_content().
if (auto res = builder.try_consume_regex(start_think_regex)) {
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
builder.try_find_regex(end_think_regex, std::string::npos, false);
// Restore position for try_parse_reasoning()
builder.move_to(res->groups[0].begin);
}
builder.try_parse_reasoning("<think>", "</think>");
// Parse response tags
static const common_regex start_response_regex(regex_escape("<response>"));
static const common_regex end_response_regex(regex_escape("</response>"));
// Granite models output partial tokens such as "<" and "<response".
// Same hack as reasoning parsing.
if (builder.try_consume_regex(start_response_regex)) {
builder.try_find_regex(end_response_regex);
}
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tool_call|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
if (auto tool_call = builder.try_consume_json_with_dumped_args({{{"arguments"}}})) {
if (!builder.add_tool_calls(tool_call->value) || tool_call->is_partial) {
throw common_chat_msg_partial_exception("incomplete tool call");
}
}
} else {
builder.add_content(builder.consume_rest());
}
}
static void common_chat_parse_nemotron_v2(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<think>", "</think>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<TOOLCALL>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
// Expect JSON array of tool calls
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
if (!builder.try_consume_literal("</TOOLCALL>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
builder.add_tool_calls(tool_calls_data.json);
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_apertus(common_chat_msg_parser & builder) {
// Parse thinking tags
builder.try_parse_reasoning("<|inner_prefix|>", "<|inner_suffix|>");
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// Look for tool calls
static const common_regex tool_call_regex(regex_escape("<|tools_prefix|>"));
if (auto res = builder.try_find_regex(tool_call_regex)) {
builder.move_to(res->groups[0].end);
auto tool_calls_data = builder.consume_json();
if (tool_calls_data.json.is_array()) {
builder.consume_spaces();
if (!builder.try_consume_literal("<|tools_suffix|>")) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
for (const auto & value : tool_calls_data.json) {
if (value.is_object()) {
builder.add_tool_call_short_form(value);
}
}
} else {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
builder.add_content(builder.consume_rest());
}
static void common_chat_parse_lfm2(common_chat_msg_parser & builder) {
if (!builder.syntax().parse_tool_calls) {
builder.add_content(builder.consume_rest());
return;
}
// LFM2 format: <|tool_call_start|>[{"name": "get_current_time", "arguments": {"location": "Paris"}}]<|tool_call_end|>
static const common_regex tool_call_start_regex(regex_escape("<|tool_call_start|>"));
static const common_regex tool_call_end_regex(regex_escape("<|tool_call_end|>"));
// Loop through all tool calls
while (auto res = builder.try_find_regex(tool_call_start_regex, std::string::npos, /* add_prelude_to_content= */ true)) {
builder.move_to(res->groups[0].end);
// Parse JSON array format: [{"name": "...", "arguments": {...}}]
auto tool_calls_data = builder.consume_json();
// Consume end marker
builder.consume_spaces();
if (!builder.try_consume_regex(tool_call_end_regex)) {
throw common_chat_msg_partial_exception("Expected <|tool_call_end|>");
}
// Process each tool call in the array
if (tool_calls_data.json.is_array()) {
for (const auto & tool_call : tool_calls_data.json) {
if (!tool_call.is_object()) {
throw common_chat_msg_partial_exception("Tool call must be an object");
}
if (!tool_call.contains("name")) {
throw common_chat_msg_partial_exception("Tool call missing 'name' field");
}
std::string function_name = tool_call.at("name");
std::string arguments = "{}";
if (tool_call.contains("arguments")) {
if (tool_call.at("arguments").is_object()) {
arguments = tool_call.at("arguments").dump();
} else if (tool_call.at("arguments").is_string()) {
arguments = tool_call.at("arguments");
}
}
if (!builder.add_tool_call(function_name, "", arguments)) {
throw common_chat_msg_partial_exception("Incomplete tool call");
}
}
} else {
throw common_chat_msg_partial_exception("Expected JSON array for tool calls");
}
// Consume any trailing whitespace after this tool call
builder.consume_spaces();
}
// Consume any remaining content after all tool calls
auto remaining = builder.consume_rest();
if (!string_strip(remaining).empty()) {
builder.add_content(remaining);
}
}
static void common_chat_parse_seed_oss(common_chat_msg_parser & builder) {
static const xml_tool_call_format form {
/* form.scope_start = */ "<seed:tool_call>",
/* form.tool_start = */ "<function=",
/* form.tool_sep = */ ">",
/* form.key_start = */ "<parameter=",
/* form.key_val_sep = */ ">",
/* form.val_end = */ "</parameter>",
/* form.tool_end = */ "</function>",
/* form.scope_end = */ "</seed:tool_call>",
};
builder.consume_reasoning_with_xml_tool_calls(form, "<seed:think>", "</seed:think>");
}
static void common_chat_parse_content_only(common_chat_msg_parser & builder) {
builder.try_parse_reasoning("<think>", "</think>");
builder.add_content(builder.consume_rest());
}
static void common_chat_parse(common_chat_msg_parser & builder) {
LOG_DBG("Parsing input with format %s: %s\n", common_chat_format_name(builder.syntax().format), builder.input().c_str());
switch (builder.syntax().format) {
case COMMON_CHAT_FORMAT_CONTENT_ONLY:
common_chat_parse_content_only(builder);
break;
case COMMON_CHAT_FORMAT_GENERIC:
common_chat_parse_generic(builder);
break;
case COMMON_CHAT_FORMAT_MISTRAL_NEMO:
common_chat_parse_mistral_nemo(builder);
break;
case COMMON_CHAT_FORMAT_MAGISTRAL:
common_chat_parse_magistral(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X:
common_chat_parse_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS:
common_chat_parse_llama_3_1(builder, /* with_builtin_tools= */ true);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_R1:
common_chat_parse_deepseek_r1(builder);
break;
case COMMON_CHAT_FORMAT_DEEPSEEK_V3_1:
common_chat_parse_deepseek_v3_1(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2:
common_chat_parse_functionary_v3_2(builder);
break;
case COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1:
common_chat_parse_functionary_v3_1_llama_3_1(builder);
break;
case COMMON_CHAT_FORMAT_HERMES_2_PRO:
common_chat_parse_hermes_2_pro(builder);
break;
case COMMON_CHAT_FORMAT_FIREFUNCTION_V2:
common_chat_parse_firefunction_v2(builder);
break;
case COMMON_CHAT_FORMAT_COMMAND_R7B:
common_chat_parse_command_r7b(builder);
break;
case COMMON_CHAT_FORMAT_GRANITE:
common_chat_parse_granite(builder);
break;
case COMMON_CHAT_FORMAT_GPT_OSS:
common_chat_parse_gpt_oss(builder);
break;
case COMMON_CHAT_FORMAT_SEED_OSS:
common_chat_parse_seed_oss(builder);
break;
case COMMON_CHAT_FORMAT_NEMOTRON_V2:
common_chat_parse_nemotron_v2(builder);
break;
case COMMON_CHAT_FORMAT_APERTUS:
common_chat_parse_apertus(builder);
break;
case COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS:
common_chat_parse_lfm2(builder);
break;
case COMMON_CHAT_FORMAT_MINIMAX_M2:
common_chat_parse_minimax_m2(builder);
break;
case COMMON_CHAT_FORMAT_GLM_4_5:
common_chat_parse_glm_4_5(builder);
break;
case COMMON_CHAT_FORMAT_KIMI_K2:
common_chat_parse_kimi_k2(builder);
break;
case COMMON_CHAT_FORMAT_QWEN3_CODER_XML:
common_chat_parse_qwen3_coder_xml(builder);
break;
case COMMON_CHAT_FORMAT_APRIEL_1_5:
common_chat_parse_apriel_1_5(builder);
break;
case COMMON_CHAT_FORMAT_XIAOMI_MIMO:
common_chat_parse_xiaomi_mimo(builder);
break;
default:
throw std::runtime_error(std::string("Unsupported format: ") + common_chat_format_name(builder.syntax().format));
}
builder.finish();
}
common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax) {
common_chat_msg_parser builder(input, is_partial, syntax);
try {
common_chat_parse(builder);
} catch (const common_chat_msg_partial_exception & ex) {
LOG_DBG("Partial parse: %s\n", ex.what());
if (!is_partial) {
builder.clear_tools();
builder.move_to(0);
common_chat_parse_content_only(builder);
}
}
auto msg = builder.result();
if (!is_partial) {
LOG_DBG("Parsed message: %s\n", common_chat_msgs_to_json_oaicompat<json>({msg}).at(0).dump().c_str());
}
return msg;
}

File diff suppressed because it is too large Load Diff

View File

@ -268,10 +268,10 @@ static bool is_reserved_name(const std::string & name) {
} }
std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+"); std::regex INVALID_RULE_CHARS_RE("[^a-zA-Z0-9-]+");
std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"]"); std::regex GRAMMAR_LITERAL_ESCAPE_RE("[\r\n\"\\\\]");
std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]"); std::regex GRAMMAR_RANGE_LITERAL_ESCAPE_RE("[\r\n\"\\]\\-\\\\]");
std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = { std::unordered_map<char, std::string> GRAMMAR_LITERAL_ESCAPES = {
{'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"} {'\r', "\\r"}, {'\n', "\\n"}, {'"', "\\\""}, {'-', "\\-"}, {']', "\\]"}, {'\\', "\\\\"}
}; };
std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'}; std::unordered_set<char> NON_LITERAL_SET = {'|', '.', '(', ')', '[', ']', '{', '}', '*', '+', '?'};

View File

@ -42,6 +42,9 @@ The following releases are verified and recommended:
## News ## News
- 2025.11
- Support malloc memory on device more than 4GB.
- 2025.2 - 2025.2
- Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC). - Optimize MUL_MAT Q4_0 on Intel GPU for all dGPUs and built-in GPUs since MTL. Increase the performance of LLM (llama-2-7b.Q4_0.gguf) 21%-87% on Intel GPUs (MTL, ARL-H, Arc, Flex, PVC).
|GPU|Base tokens/s|Increased tokens/s|Percent| |GPU|Base tokens/s|Increased tokens/s|Percent|
@ -789,6 +792,8 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. | | GGML_SYCL_DISABLE_GRAPH | 0 or 1 (default) | Disable running computations through SYCL Graphs feature. Disabled by default because graph performance isn't yet better than non-graph performance. |
| GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. | | GGML_SYCL_DISABLE_DNN | 0 (default) or 1 | Disable running computations through oneDNN and always use oneMKL. |
| ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer | | ZES_ENABLE_SYSMAN | 0 (default) or 1 | Support to get free memory of GPU by sycl::aspect::ext_intel_free_memory.<br>Recommended to use when --split-mode = layer |
| UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS | 0 (default) or 1 | Support malloc device memory more than 4GB.|
## Known Issues ## Known Issues
@ -835,6 +840,14 @@ use 1 SYCL GPUs: [0] with Max compute units:512
| The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.| | The default context is too big. It leads to excessive memory usage.|Set `-c 8192` or a smaller value.|
| The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.| | The model is too big and requires more memory than what is available.|Choose a smaller model or change to a smaller quantization, like Q5 -> Q4;<br>Alternatively, use more than one device to load model.|
- `ggml_backend_sycl_buffer_type_alloc_buffer: can't allocate 5000000000 Bytes of memory on device`
You need to enable to support 4GB memory malloc by:
```
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
```
### **GitHub contribution**: ### **GitHub contribution**:
Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay. Please add the `SYCL :` prefix/tag in issues/PRs titles to help the SYCL contributors to check/address them without delay.

View File

@ -104,12 +104,16 @@ int main(int argc, char ** argv) {
params.embedding = true; params.embedding = true;
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
// if the number of prompts that would be encoded is known in advance, it's more efficient to specify the // if the number of prompts that would be encoded is known in advance, it's more efficient to specify the
// --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache // --parallel argument accordingly. for convenience, if not specified, we fallback to unified KV cache
// in order to support any number of prompts // in order to support any number of prompts
if (params.n_parallel == 1) { if (params.n_parallel == 1) {
LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__); LOG_INF("%s: n_parallel == 1 -> unified KV cache is enabled\n", __func__);
params.kv_unified = true; params.kv_unified = true;
params.n_parallel = n_seq_max;
} }
// utilize the full context // utilize the full context
@ -123,9 +127,6 @@ int main(int argc, char ** argv) {
params.n_ubatch = params.n_batch; params.n_ubatch = params.n_batch;
} }
// get max number of sequences per batch
const int n_seq_max = llama_max_parallel_sequences();
llama_backend_init(); llama_backend_init();
llama_numa_init(params.numa); llama_numa_init(params.numa);

View File

@ -231,9 +231,9 @@ DOT = '[^\\x0A\\x0D]'
RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()]) RESERVED_NAMES = set(["root", "dot", *PRIMITIVE_RULES.keys(), *STRING_FORMAT_RULES.keys()])
INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+') INVALID_RULE_CHARS_RE = re.compile(r'[^a-zA-Z0-9-]+')
GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"]') GRAMMAR_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\\]')
GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]') GRAMMAR_RANGE_LITERAL_ESCAPE_RE = re.compile(r'[\r\n"\]\-\\]')
GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]'} GRAMMAR_LITERAL_ESCAPES = {'\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\'}
NON_LITERAL_SET = set('|.()[]{}*+?') NON_LITERAL_SET = set('|.()[]{}*+?')
ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?') ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = set('^$.[]()|{}*+?')

View File

@ -15,6 +15,9 @@ MODEL_FILE=models/llama-2-7b.Q4_0.gguf
NGL=99 NGL=99
CONTEXT=4096 CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1 GGML_SYCL_DEVICE=$1
echo "use $GGML_SYCL_DEVICE as main GPU" echo "use $GGML_SYCL_DEVICE as main GPU"

View File

@ -6,7 +6,7 @@
# If you want more control, DPC++ Allows selecting a specific device through the # If you want more control, DPC++ Allows selecting a specific device through the
# following environment variable # following environment variable
#export ONEAPI_DEVICE_SELECTOR="level_zero:0" export ONEAPI_DEVICE_SELECTOR="level_zero:0"
source /opt/intel/oneapi/setvars.sh source /opt/intel/oneapi/setvars.sh
#export GGML_SYCL_DEBUG=1 #export GGML_SYCL_DEBUG=1
@ -18,11 +18,14 @@ MODEL_FILE=models/Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf
NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using. NGL=99 # Layers offloaded to the GPU. If the device runs out of memory, reduce this value according to the model you are using.
CONTEXT=4096 CONTEXT=4096
#support malloc device memory more than 4GB.
export UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
if [ $# -gt 0 ]; then if [ $# -gt 0 ]; then
GGML_SYCL_DEVICE=$1 GGML_SYCL_DEVICE=$1
echo "Using $GGML_SYCL_DEVICE as the main GPU" echo "Using $GGML_SYCL_DEVICE as the main GPU"
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT} -mg $GGML_SYCL_DEVICE -sm none
else else
#use multiple GPUs with same max compute units #use multiple GPUs with same max compute units
ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -c ${CONTEXT} ZES_ENABLE_SYSMAN=1 ./build/bin/llama-cli -m ${MODEL_FILE} -p "${INPUT_PROMPT}" -n 400 -e -ngl ${NGL} -s 0 -c ${CONTEXT}
fi fi

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0 .\build\bin\llama-cli.exe -m models\llama-2-7b.Q4_0.gguf -p %INPUT2% -n 400 -e -ngl 99 -s 0

View File

@ -5,5 +5,7 @@
set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:" set INPUT2="Building a website can be done in 10 simple steps:\nStep 1:"
@call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force @call "C:\Program Files (x86)\Intel\oneAPI\setvars.bat" intel64 --force
:: support malloc device memory more than 4GB.
set UR_L0_ENABLE_RELAXED_ALLOCATION_LIMITS=1
.\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -e -ngl 99 .\build\bin\llama-cli.exe -m models\Meta-Llama-3.1-8B-Instruct-Q4_K_M.gguf -p %INPUT2% -n 400 -s 0 -e -ngl 99

View File

@ -183,6 +183,7 @@ endif()
# ggml core # ggml core
set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism") set(GGML_SCHED_MAX_COPIES "4" CACHE STRING "ggml: max input copies for pipeline parallelism")
option(GGML_CPU "ggml: enable CPU backend" ON) option(GGML_CPU "ggml: enable CPU backend" ON)
option(GGML_SCHED_NO_REALLOC "ggml: disallow reallocations in ggml-alloc (for debugging)" OFF)
# 3rd party libs / backends # 3rd party libs / backends
option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON) option(GGML_ACCELERATE "ggml: enable Accelerate framework" ON)

View File

@ -221,6 +221,10 @@ if (GGML_BACKEND_DL)
target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL) target_compile_definitions(ggml-base PUBLIC GGML_BACKEND_DL)
endif() endif()
if (GGML_SCHED_NO_REALLOC)
target_compile_definitions(ggml-base PUBLIC GGML_SCHED_NO_REALLOC)
endif()
add_library(ggml add_library(ggml
ggml-backend-reg.cpp) ggml-backend-reg.cpp)
add_library(ggml::ggml ALIAS ggml) add_library(ggml::ggml ALIAS ggml)

View File

@ -921,10 +921,15 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c
} }
if (realloc) { if (realloc) {
#ifndef NDEBUG #ifndef NDEBUG
{
size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0;
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); if (cur_size > 0) {
GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n",
__func__, ggml_backend_buft_name(galloc->bufts[i]),
cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0);
}
}
#endif #endif
ggml_vbuffer_free(galloc->buffers[i]); ggml_vbuffer_free(galloc->buffers[i]);
galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE);
if (galloc->buffers[i] == NULL) { if (galloc->buffers[i] == NULL) {

View File

@ -1395,14 +1395,20 @@ static bool ggml_backend_sched_alloc_splits(ggml_backend_sched_t sched) {
// allocate graph // allocate graph
if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { if (backend_ids_changed || !ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
#ifdef GGML_SCHED_NO_REALLOC
GGML_ABORT("%s: failed to allocate graph, but graph re-allocation is disabled by GGML_SCHED_NO_REALLOC\n", __func__);
#endif
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
// the re-allocation may cause the split inputs to be moved to a different address // the re-allocation may cause the split inputs to be moved to a different address
// synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy // synchronize without ggml_backend_sched_synchronize to avoid changing cur_copy
for (int i = 0; i < sched->n_backends; i++) { for (int i = 0; i < sched->n_backends; i++) {
ggml_backend_synchronize(sched->backends[i]); ggml_backend_synchronize(sched->backends[i]);
} }
#ifndef NDEBUG
GGML_LOG_DEBUG("%s: failed to allocate graph, reserving (backend_ids_changed = %d)\n", __func__, backend_ids_changed);
#endif
ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids); ggml_gallocr_reserve_n(sched->galloc, &sched->graph, sched->node_backend_ids, sched->leaf_backend_ids);
if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) { if (!ggml_gallocr_alloc_graph(sched->galloc, &sched->graph)) {
GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__); GGML_LOG_ERROR("%s: failed to allocate graph\n", __func__);

View File

@ -1,20 +1,23 @@
#include "ggml-backend-impl.h" #include "ggml-backend-impl.h"
#if defined(__riscv) && __riscv_xlen == 64 #if defined(__riscv) && __riscv_xlen == 64
#include <sys/auxv.h> #include <asm/hwprobe.h>
#include <asm/unistd.h>
//https://github.com/torvalds/linux/blob/master/arch/riscv/include/uapi/asm/hwcap.h#L24 #include <unistd.h>
#ifndef COMPAT_HWCAP_ISA_V
#define COMPAT_HWCAP_ISA_V (1 << ('V' - 'A'))
#endif
struct riscv64_features { struct riscv64_features {
bool has_rvv = false; bool has_rvv = false;
riscv64_features() { riscv64_features() {
uint32_t hwcap = getauxval(AT_HWCAP); struct riscv_hwprobe probe;
probe.key = RISCV_HWPROBE_KEY_IMA_EXT_0;
probe.value = 0;
has_rvv = !!(hwcap & COMPAT_HWCAP_ISA_V); int ret = syscall(__NR_riscv_hwprobe, &probe, 1, 0, NULL, 0);
if (0 == ret) {
has_rvv = !!(probe.value & RISCV_HWPROBE_IMA_V);
}
} }
}; };

View File

@ -84,12 +84,12 @@
#define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000 #define GGML_CUDA_CC_QY1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x210) // MTT S80, MTT S3000
#define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000 #define GGML_CUDA_CC_QY2 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x220) // MTT S4000
#define GGML_CUDA_CC_NG (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // TBD #define GGML_CUDA_CC_PH1 (GGML_CUDA_CC_OFFSET_MTHREADS + 0x310) // MTT S5000
#define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD) #define GGML_CUDA_CC_IS_MTHREADS(cc) (cc >= GGML_CUDA_CC_OFFSET_MTHREADS && cc < GGML_CUDA_CC_OFFSET_AMD)
#define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2) #define GGML_CUDA_CC_IS_QY1(cc) (cc >= GGML_CUDA_CC_QY1 && cc < GGML_CUDA_CC_QY2)
#define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_QY2(cc) (cc >= GGML_CUDA_CC_QY2 && cc < GGML_CUDA_CC_PH1)
#define GGML_CUDA_CC_IS_NG(cc) (cc >= GGML_CUDA_CC_NG) #define GGML_CUDA_CC_IS_PH1(cc) (cc >= GGML_CUDA_CC_PH1)
#if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070 #if !defined(GGML_USE_HIP) && !defined(GGML_USE_MUSA) && CUDART_VERSION >= 11070
# define GGML_CUDA_USE_CUB # define GGML_CUDA_USE_CUB
@ -212,9 +212,9 @@ static const char * cu_get_error_str(CUresult err) {
#define GGML_USE_VMM #define GGML_USE_VMM
#endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM)) #endif // (!defined(GGML_USE_HIP) && !defined(GGML_CUDA_NO_VMM)) || (defined(GGML_USE_HIP) && !defined(GGML_HIP_NO_VMM))
#if defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #if defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#define FP16_AVAILABLE #define FP16_AVAILABLE
#endif // defined(GGML_USE_HIP) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL #endif // defined(GGML_USE_HIP) || defined(GGML_USE_MUSA) || __CUDA_ARCH__ >= GGML_CUDA_CC_PASCAL
#if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 #if defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610
#define FAST_FP16_AVAILABLE #define FAST_FP16_AVAILABLE
@ -250,12 +250,14 @@ static const char * cu_get_error_str(CUresult err) {
#endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220) #endif // !defined(GGML_CUDA_NO_FA) && !(defined(GGML_USE_MUSA) && __MUSA_ARCH__ < 220)
static bool fp16_available(const int cc) { static bool fp16_available(const int cc) {
return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL; return ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_PASCAL ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
} }
static bool fast_fp16_available(const int cc) { static bool fast_fp16_available(const int cc) {
return GGML_CUDA_CC_IS_AMD(cc) || return GGML_CUDA_CC_IS_AMD(cc) ||
(GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610) ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && fp16_available(cc));
} }
// To be used for feature selection of external libraries, e.g. cuBLAS. // To be used for feature selection of external libraries, e.g. cuBLAS.
@ -272,7 +274,9 @@ static bool fp16_mma_hardware_available(const int cc) {
} }
static bool bf16_mma_hardware_available(const int cc) { static bool bf16_mma_hardware_available(const int cc) {
return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) || GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3; return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_AMPERE) ||
GGML_CUDA_CC_IS_CDNA(cc) || cc >= GGML_CUDA_CC_RDNA3 ||
(GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_PH1);
} }
static bool fp32_mma_hardware_available(const int cc) { static bool fp32_mma_hardware_available(const int cc) {

View File

@ -86,6 +86,9 @@ static __global__ void cpy_scalar_transpose(const char * cx, char * cdst, const
} }
} }
} }
GGML_UNUSED_VARS(ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11,
nb12, nb13);
} }
static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) { static __device__ void cpy_blck_q8_0_f32(const char * cxi, char * cdsti) {
@ -202,7 +205,7 @@ static void ggml_cpy_scalar_cuda(
ne00n = ne00; ne00n = ne00;
ne01n = ne01; ne01n = ne01;
ne02n = ne02; ne02n = ne02;
} else if (nb00 > nb02) { } else {
ne00n = ne00; ne00n = ne00;
ne01n = ne01*ne02; ne01n = ne01*ne02;
ne02n = 1; ne02n = 1;

View File

@ -609,7 +609,7 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
float KQ_sum_add = 0.0f; float KQ_sum_add = 0.0f;
#pragma unroll #pragma unroll
for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < static_cast<uint32_t>(k_VKQ_sup) ?
expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
KQ_sum_add += val; KQ_sum_add += val;
tmp[i0/(np*warp_size)][jc1] = val; tmp[i0/(np*warp_size)][jc1] = val;

View File

@ -155,7 +155,7 @@ static __global__ void flash_attn_ext_vec(
for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) { for (int i0 = 0; i0 < int(D/sizeof(int)); i0 += WARP_SIZE) {
const int i = i0 + threadIdx.x; const int i = i0 + threadIdx.x;
if (i0 + WARP_SIZE <= D/sizeof(int) || i < D/sizeof(int)) { if (i0 + WARP_SIZE <= int(D/sizeof(int)) || i < int(D/sizeof(int))) {
tmp_q_i32[i] = 0; tmp_q_i32[i] = 0;
} }
} }
@ -272,7 +272,7 @@ static __global__ void flash_attn_ext_vec(
KQ_max_new[j] = fmaxf(KQ_max_new[j], sum); KQ_max_new[j] = fmaxf(KQ_max_new[j], sum);
if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == i_KQ_0) { if ((nthreads_KQ == WARP_SIZE ? threadIdx.x : threadIdx.x % nthreads_KQ) == uint32_t(i_KQ_0)) {
KQ_reg[j] = sum; KQ_reg[j] = sum;
} }
} }

View File

@ -3050,7 +3050,12 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax = std::initializer_list<enum ggml_op> topk_moe_ops_delayed_softmax =
ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true); ggml_cuda_topk_moe_ops(/*with_norm=*/false, /*delayed_softmax=*/true);
if (ops.size() == topk_moe_ops_with_norm.size() && const auto is_equal = [](const std::initializer_list<enum ggml_op> & list1,
const std::initializer_list<enum ggml_op> & list2) {
return std::equal(list1.begin(), list1.end(), list2.begin(), list2.end());
};
if (is_equal(topk_moe_ops_with_norm, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) { ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 9 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 9]; ggml_tensor * weights = cgraph->nodes[node_idx + 9];
@ -3060,8 +3065,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == topk_moe_ops.size() && if (is_equal(topk_moe_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 3, node_idx + 4 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx]; ggml_tensor * softmax = cgraph->nodes[node_idx];
ggml_tensor * weights = cgraph->nodes[node_idx + 4]; ggml_tensor * weights = cgraph->nodes[node_idx + 4];
if (ggml_cuda_should_use_topk_moe(softmax, weights)) { if (ggml_cuda_should_use_topk_moe(softmax, weights)) {
@ -3069,7 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == topk_moe_ops_delayed_softmax.size() && if (is_equal(topk_moe_ops_delayed_softmax, ops) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) { ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 1, node_idx + 5 })) {
ggml_tensor * softmax = cgraph->nodes[node_idx + 4]; ggml_tensor * softmax = cgraph->nodes[node_idx + 4];
ggml_tensor * weights = cgraph->nodes[node_idx + 5]; ggml_tensor * weights = cgraph->nodes[node_idx + 5];
@ -3085,9 +3089,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU }; std::initializer_list<enum ggml_op> mul_mat_id_glu_ops = { GGML_OP_MUL_MAT_ID, GGML_OP_MUL_MAT_ID, GGML_OP_GLU };
std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU }; std::initializer_list<enum ggml_op> mul_mat_glu_ops = { GGML_OP_MUL_MAT, GGML_OP_MUL_MAT, GGML_OP_GLU };
if (ops.size() == 5 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}) || if ((is_equal(mul_mat_bias_glu_ops, ops) || is_equal(mul_mat_id_bias_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 4}))) { ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 4 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_gate_bias = cgraph->nodes[node_idx + 1];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 2];
@ -3099,9 +3102,8 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == 3 && (ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}) || if ((is_equal(mul_mat_id_glu_ops, ops) || is_equal(mul_mat_glu_ops, ops)) &&
ggml_can_fuse_subgraph(cgraph, node_idx, ops, {node_idx + 2}))) { ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * ffn_gate = cgraph->nodes[node_idx]; const ggml_tensor * ffn_gate = cgraph->nodes[node_idx];
const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1]; const ggml_tensor * ffn_up = cgraph->nodes[node_idx + 1];
const ggml_tensor * glu = cgraph->nodes[node_idx + 2]; const ggml_tensor * glu = cgraph->nodes[node_idx + 2];
@ -3111,7 +3113,9 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx,
} }
} }
if (ops.size() == 3 && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) { std::initializer_list<enum ggml_op> rope_set_rows_ops = { GGML_OP_ROPE, GGML_OP_VIEW, GGML_OP_SET_ROWS };
if (is_equal(rope_set_rows_ops, ops) && ggml_can_fuse_subgraph(cgraph, node_idx, ops, { node_idx + 2 })) {
const ggml_tensor * rope = cgraph->nodes[node_idx]; const ggml_tensor * rope = cgraph->nodes[node_idx];
const ggml_tensor * view = cgraph->nodes[node_idx + 1]; const ggml_tensor * view = cgraph->nodes[node_idx + 1];
const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2]; const ggml_tensor * set_rows = cgraph->nodes[node_idx + 2];

View File

@ -889,8 +889,8 @@ namespace ggml_cuda_mma {
: "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7]) : "+r"(Dxi[0]), "+r"(Dxi[1]), "+r"(Dxi[2]), "+r"(Dxi[3]), "+r"(Dxi[4]), "+r"(Dxi[5]), "+r"(Dxi[6]), "+r"(Dxi[7])
: "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7])); : "r"(Axi[6]), "r"(Axi[7]), "r"(Bxi[6]), "r"(Bxi[7]));
#else #else
tile<16, 8, float> * D16 = (tile<16, 8, float> *) &D; tile <16, 8, float> * D16 = reinterpret_cast<tile <16, 8, float> *>(&D);
tile<16, 8, half2> * A16 = (tile<16, 8, half2> *) &A; const tile<16, 8, half2> * A16 = reinterpret_cast<const tile<16, 8, half2> *>(&A);
mma(D16[0], A16[0], B); mma(D16[0], A16[0], B);
mma(D16[1], A16[1], B); mma(D16[1], A16[1], B);
#endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE #endif // __CUDA_ARCH__ >= GGML_CUDA_CC_AMPERE

View File

@ -91,7 +91,10 @@ if (GGML_SYCL_F16)
add_compile_definitions(GGML_SYCL_F16) add_compile_definitions(GGML_SYCL_F16)
endif() endif()
if (GGML_SYCL_TARGET STREQUAL "NVIDIA") if (GGML_SYCL_TARGET STREQUAL "INTEL")
add_compile_definitions(GGML_SYCL_WARP_SIZE=16)
target_link_options(ggml-sycl PRIVATE -Xs -ze-intel-greater-than-4GB-buffer-required)
elseif (GGML_SYCL_TARGET STREQUAL "NVIDIA")
add_compile_definitions(GGML_SYCL_WARP_SIZE=32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
elseif (GGML_SYCL_TARGET STREQUAL "AMD") elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# INFO: Allowed Sub_group_sizes are not consistent through all # INFO: Allowed Sub_group_sizes are not consistent through all
@ -100,7 +103,8 @@ elseif (GGML_SYCL_TARGET STREQUAL "AMD")
# Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32) # Target archs tested working: gfx1030, gfx1031, (Only tested sub_group_size = 32)
add_compile_definitions(GGML_SYCL_WARP_SIZE=32) add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
else() else()
add_compile_definitions(GGML_SYCL_WARP_SIZE=16) # default for other target
add_compile_definitions(GGML_SYCL_WARP_SIZE=32)
endif() endif()
if (GGML_SYCL_GRAPH) if (GGML_SYCL_GRAPH)

View File

@ -515,9 +515,6 @@ void ggml_sycl_cpy(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, co
const int64_t ne = ggml_nelements(src0); const int64_t ne = ggml_nelements(src0);
GGML_ASSERT(ne == ggml_nelements(src1)); GGML_ASSERT(ne == ggml_nelements(src1));
GGML_ASSERT(ggml_nbytes(src0) <= INT_MAX);
GGML_ASSERT(ggml_nbytes(src1) <= INT_MAX);
GGML_TENSOR_BINARY_OP_LOCALS01; GGML_TENSOR_BINARY_OP_LOCALS01;
SYCL_CHECK(ggml_sycl_set_device(ctx.device)); SYCL_CHECK(ggml_sycl_set_device(ctx.device));

View File

@ -613,9 +613,10 @@ struct vk_device_struct {
vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant[GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols];
vk_pipeline pipeline_dequant_mul_mat_vec_id_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT];
vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio];
vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; vk_pipeline pipeline_mul_mat_vec_nc_f16_f32;
@ -1611,7 +1612,7 @@ class vk_perf_logger {
} }
if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) {
const uint64_t m = node->src[0]->ne[1]; const uint64_t m = node->src[0]->ne[1];
const uint64_t n = node->ne[1]; const uint64_t n = (node->op == GGML_OP_MUL_MAT) ? node->ne[1] : node->ne[2];
const uint64_t k = node->src[1]->ne[0]; const uint64_t k = node->src[1]->ne[0];
const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3];
std::string name = ggml_op_name(node->op); std::string name = ggml_op_name(node->op);
@ -3525,13 +3526,18 @@ static void ggml_vk_load_shaders(vk_device& device) {
// the number of rows computed per shader depends on GPU model and quant // the number of rows computed per shader depends on GPU model and quant
uint32_t rm_stdq = 1; uint32_t rm_stdq = 1;
uint32_t rm_kq = 2; uint32_t rm_kq = 2;
uint32_t rm_stdq_int = 1;
uint32_t rm_kq_int = 1;
if (device->vendor_id == VK_VENDOR_ID_AMD) { if (device->vendor_id == VK_VENDOR_ID_AMD) {
if (device->architecture == AMD_GCN) { if (device->architecture == AMD_GCN) {
rm_stdq = 2; rm_stdq = 2;
rm_kq = 4; rm_kq = 4;
rm_stdq_int = 4;
} }
} else if (device->vendor_id == VK_VENDOR_ID_INTEL) } else if (device->vendor_id == VK_VENDOR_ID_INTEL) {
rm_stdq = 2; rm_stdq = 2;
rm_stdq_int = 2;
}
uint32_t rm_iq = 2 * rm_kq; uint32_t rm_iq = 2 * rm_kq;
const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN; const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
@ -3612,39 +3618,73 @@ static void ggml_vk_load_shaders(vk_device& device) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4); const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_q8_1_f32", arr_dmmv_mxfp4_q8_1_f32_len[reduc], arr_dmmv_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_q8_1_f32", arr_dmmv_q2_k_q8_1_f32_len[reduc], arr_dmmv_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_q8_1_f32", arr_dmmv_q3_k_q8_1_f32_len[reduc], arr_dmmv_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_q8_1_f32", arr_dmmv_q4_k_q8_1_f32_len[reduc], arr_dmmv_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_q8_1_f32", arr_dmmv_q5_k_q8_1_f32_len[reduc], arr_dmmv_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_q8_1_f32", arr_dmmv_q6_k_q8_1_f32_len[reduc], arr_dmmv_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int, i+1}, 1, true, use_subgroups, subgroup_size_int);
} }
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT #endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", arr_dmmv_id_f32_f32_f32_len[reduc], arr_dmmv_id_f32_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", arr_dmmv_id_f16_f32_f32_len[reduc], arr_dmmv_id_f16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", arr_dmmv_id_bf16_f32_f32_len[reduc], arr_dmmv_id_bf16_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {wg_size_subgroup, 2}, 1, false, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", arr_dmmv_id_q4_0_f32_f32_len[reduc], arr_dmmv_id_q4_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", arr_dmmv_id_q4_1_f32_f32_len[reduc], arr_dmmv_id_q4_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", arr_dmmv_id_q5_0_f32_f32_len[reduc], arr_dmmv_id_q5_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", arr_dmmv_id_q5_1_f32_f32_len[reduc], arr_dmmv_id_q5_1_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", arr_dmmv_id_q8_0_f32_f32_len[reduc], arr_dmmv_id_q8_0_f32_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq}, 1, true, use_subgroups, force_subgroup_size);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", arr_dmmv_id_q2_k_f32_f32_len[reduc16], arr_dmmv_id_q2_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", arr_dmmv_id_q3_k_f32_f32_len[reduc16], arr_dmmv_id_q3_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", arr_dmmv_id_q4_k_f32_f32_len[reduc16], arr_dmmv_id_q4_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", arr_dmmv_id_q5_k_f32_f32_len[reduc16], arr_dmmv_id_q5_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", arr_dmmv_id_q6_k_f32_f32_len[reduc16], arr_dmmv_id_q6_k_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", arr_dmmv_id_iq1_s_f32_f32_len[reduc16], arr_dmmv_id_iq1_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", arr_dmmv_id_iq1_m_f32_f32_len[reduc16], arr_dmmv_id_iq1_m_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", arr_dmmv_id_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", arr_dmmv_id_iq2_xs_f32_f32_len[reduc16], arr_dmmv_id_iq2_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", arr_dmmv_id_iq2_s_f32_f32_len[reduc16], arr_dmmv_id_iq2_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", arr_dmmv_id_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_id_iq3_xxs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", arr_dmmv_id_iq3_s_f32_f32_len[reduc16], arr_dmmv_id_iq3_s_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", arr_dmmv_id_iq4_xs_f32_f32_len[reduc16], arr_dmmv_id_iq4_xs_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", arr_dmmv_id_iq4_nl_f32_f32_len[reduc16], arr_dmmv_id_iq4_nl_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", arr_dmmv_id_mxfp4_f32_f32_len[reduc16], arr_dmmv_id_mxfp4_f32_f32_data[reduc16], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq}, 1, true, use_subgroups16, force_subgroup_size16);
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (device->integer_dot_product) {
const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size;
const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size_int : (subgroup_size_int * 4);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_q8_1_f32", arr_dmmv_id_q4_0_q8_1_f32_len[reduc], arr_dmmv_id_q4_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_q8_1_f32", arr_dmmv_id_q4_1_q8_1_f32_len[reduc], arr_dmmv_id_q4_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_q8_1_f32", arr_dmmv_id_q5_0_q8_1_f32_len[reduc], arr_dmmv_id_q5_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_q8_1_f32", arr_dmmv_id_q5_1_q8_1_f32_len[reduc], arr_dmmv_id_q5_1_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_q8_1_f32", arr_dmmv_id_q8_0_q8_1_f32_len[reduc], arr_dmmv_id_q8_0_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_q8_1_f32", arr_dmmv_id_mxfp4_q8_1_f32_len[reduc], arr_dmmv_id_mxfp4_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_stdq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_q8_1_f32", arr_dmmv_id_q2_k_q8_1_f32_len[reduc], arr_dmmv_id_q2_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {2*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 2*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_q8_1_f32", arr_dmmv_id_q3_k_q8_1_f32_len[reduc], arr_dmmv_id_q3_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_q8_1_f32", arr_dmmv_id_q4_k_q8_1_f32_len[reduc], arr_dmmv_id_q4_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_q8_1_f32", arr_dmmv_id_q5_k_q8_1_f32_len[reduc], arr_dmmv_id_q5_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[w][GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_q8_1_f32", arr_dmmv_id_q6_k_q8_1_f32_len[reduc], arr_dmmv_id_q6_k_q8_1_f32_data[reduc], "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_push_constants), {1*rm_kq_int, 1, 1}, {wg_size_subgroup_int, 1*rm_kq_int}, 1, true, use_subgroups, subgroup_size_int);
}
#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT
} }
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); #if !defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); GGML_UNUSED(rm_stdq_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); GGML_UNUSED(rm_kq_int);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); #endif
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", mul_mat_vec_id_num_bindings, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true);
// dequant shaders // dequant shaders
ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1);
@ -5453,6 +5493,12 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context *
case GGML_TYPE_Q5_0: case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1: case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break; break;
default: default:
return nullptr; return nullptr;
@ -5592,9 +5638,28 @@ static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_co
} }
} }
static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t m, uint32_t k) {
VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()");
GGML_ASSERT(b_type == GGML_TYPE_F32); GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_Q8_1);
if (b_type == GGML_TYPE_Q8_1) {
switch (a_type) {
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q4_1:
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_Q8_0:
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q2_K:
case GGML_TYPE_Q3_K:
case GGML_TYPE_Q4_K:
case GGML_TYPE_Q5_K:
case GGML_TYPE_Q6_K:
break;
default:
return nullptr;
}
}
switch (a_type) { switch (a_type) {
case GGML_TYPE_F32: case GGML_TYPE_F32:
@ -5625,7 +5690,31 @@ static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context
return nullptr; return nullptr;
} }
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; // heuristic to choose workgroup size
uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
// Prefer larger workgroups when M is small, to spread the work out more
// and keep more SMs busy.
// q6_k seems to prefer small workgroup size even for "medium" values of M.
if (a_type == GGML_TYPE_Q6_K) {
if (m < 4096 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
} else {
if (m <= 8192 && k >= 1024) {
dmmv_wg = DMMV_WG_SIZE_LARGE;
}
}
}
if (b_type == GGML_TYPE_Q8_1) {
if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) {
dmmv_wg = DMMV_WG_SIZE_SUBGROUP;
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_q8_1_f32[dmmv_wg][a_type];
}
return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[dmmv_wg][a_type];
} }
static void * ggml_vk_host_malloc(vk_device& device, size_t size) { static void * ggml_vk_host_malloc(vk_device& device, size_t size) {
@ -6817,20 +6906,35 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return false; return false;
} }
// General performance issue with q3_k and q6_k due to 2-byte alignment
if (src0_type == GGML_TYPE_Q3_K || src0_type == GGML_TYPE_Q6_K) {
return false;
}
// MMVQ is generally good for batches // MMVQ is generally good for batches
if (n > 1) { if (n > 1) {
return true; return true;
} }
// Quantization overhead is not worth it for small k
switch (device->vendor_id) { switch (device->vendor_id) {
case VK_VENDOR_ID_NVIDIA: case VK_VENDOR_ID_NVIDIA:
if (k <= 4096) {
return false;
}
switch (src0_type) { switch (src0_type) {
case GGML_TYPE_MXFP4:
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING;
default: default:
return true; return true;
} }
case VK_VENDOR_ID_AMD: case VK_VENDOR_ID_AMD:
if (k < 2048) {
return false;
}
switch (src0_type) { switch (src0_type) {
case GGML_TYPE_Q8_0: case GGML_TYPE_Q8_0:
return device->architecture == vk_device_architecture::AMD_GCN; return device->architecture == vk_device_architecture::AMD_GCN;
@ -6838,6 +6942,10 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
return true; return true;
} }
case VK_VENDOR_ID_INTEL: case VK_VENDOR_ID_INTEL:
if (k < 2048) {
return false;
}
switch (src0_type) { switch (src0_type) {
// From tests on A770 Linux, may need more tuning // From tests on A770 Linux, may need more tuning
case GGML_TYPE_Q4_0: case GGML_TYPE_Q4_0:
@ -6851,7 +6959,6 @@ static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_
} }
GGML_UNUSED(m); GGML_UNUSED(m);
GGML_UNUSED(k);
} }
static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) { static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const struct ggml_cgraph * cgraph, int node_idx) {
@ -7574,7 +7681,7 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
if (x_non_contig || qx_needs_dequant) { if (x_non_contig || qx_needs_dequant) {
ctx->prealloc_x_need_sync = true; ctx->prealloc_x_need_sync = true;
} }
if (y_non_contig) { if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true; ctx->prealloc_y_need_sync = true;
} }
} }
@ -7600,7 +7707,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const uint64_t ne10 = src1->ne[0]; const uint64_t ne10 = src1->ne[0];
const uint64_t ne11 = src1->ne[1]; const uint64_t ne11 = src1->ne[1];
// const uint64_t ne12 = src1->ne[2]; const uint64_t ne12 = src1->ne[2];
// const uint64_t ne13 = src1->ne[3]; // const uint64_t ne13 = src1->ne[3];
const uint64_t nei0 = ids->ne[0]; const uint64_t nei0 = ids->ne[0];
@ -7617,19 +7724,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); const bool y_non_contig = !ggml_vk_dim01_contiguous(src1);
const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; const bool f16_f32_kernel = src1->type == GGML_TYPE_F32;
bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && !y_non_contig && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne12, ne10, src0->type);
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig;
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne;
vk_pipeline to_fp16_vk_0 = nullptr; vk_pipeline to_fp16_vk_0 = nullptr;
vk_pipeline to_fp16_vk_1 = nullptr; vk_pipeline to_fp16_vk_1 = nullptr;
@ -7641,11 +7736,38 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else { } else {
to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type);
} }
vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type);
// Check for mmq first
vk_pipeline dmmv = quantize_y ? ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, GGML_TYPE_Q8_1, ne20, ne00) : nullptr;
vk_pipeline to_q8_1 = nullptr;
if (dmmv == nullptr) {
// Fall back to f16 dequant mul mat
dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type, ne20, ne00);
quantize_y = false;
}
if (quantize_y) {
to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1);
}
const bool qx_needs_dequant = x_non_contig;
const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig);
// Not implemented
GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT
GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT
GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT
GGML_ASSERT(dmmv != nullptr); GGML_ASSERT(dmmv != nullptr);
const uint64_t x_ne = ggml_nelements(src0);
const uint64_t y_ne = ggml_nelements(src1);
const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment);
const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz;
const uint64_t y_sz = quantize_y ? (ggml_vk_align_size(y_ne, 128) * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) :
(f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne);
{ {
if ( if (
(qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) || (qx_needs_dequant && x_sz > ctx->device->properties.limits.maxStorageBufferRange) ||
@ -7656,7 +7778,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_size_x = x_sz; ctx->prealloc_size_x = x_sz;
ggml_vk_preallocate_buffers(ctx, subctx); ggml_vk_preallocate_buffers(ctx, subctx);
} }
if (qy_needs_dequant && ctx->prealloc_size_y < y_sz) { if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz) {
ctx->prealloc_size_y = y_sz; ctx->prealloc_size_y = y_sz;
ggml_vk_preallocate_buffers(ctx, subctx); ggml_vk_preallocate_buffers(ctx, subctx);
} }
@ -7668,6 +7790,9 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (qy_needs_dequant) { if (qy_needs_dequant) {
ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1);
} }
if (quantize_y) {
ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1);
}
ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1);
} }
@ -7683,7 +7808,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
} else { } else {
d_X = d_Qx; d_X = d_Qx;
} }
if (qy_needs_dequant) { if (qy_needs_dequant || quantize_y) {
d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size }; d_Y = { ctx->prealloc_y, 0, ctx->prealloc_y->size };
} else { } else {
d_Y = d_Qy; d_Y = d_Qy;
@ -7711,6 +7836,17 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
ctx->prealloc_y_last_tensor_used = src1; ctx->prealloc_y_last_tensor_used = src1;
} }
} }
if (quantize_y) {
if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() ||
ctx->prealloc_y_last_tensor_used != src1) {
if (ctx->prealloc_y_need_sync) {
ggml_vk_sync_buffers(ctx, subctx);
}
ggml_vk_quantize_q8_1(ctx, subctx, d_Qy, d_Y, y_ne);
ctx->prealloc_y_last_pipeline_used = to_q8_1.get();
ctx->prealloc_y_last_tensor_used = src1;
}
}
uint32_t stride_batch_y = ne10*ne11; uint32_t stride_batch_y = ne10*ne11;
@ -7772,7 +7908,7 @@ static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_conte
if (x_non_contig) { if (x_non_contig) {
ctx->prealloc_x_need_sync = true; ctx->prealloc_x_need_sync = true;
} }
if (y_non_contig) { if (y_non_contig || quantize_y) {
ctx->prealloc_y_need_sync = true; ctx->prealloc_y_need_sync = true;
} }
} }
@ -10239,7 +10375,9 @@ static void ggml_vk_topk(ggml_backend_vk_context * ctx, vk_context& subctx, cons
// Prefer going as small as num_topk_pipelines - 3 for perf reasons. // Prefer going as small as num_topk_pipelines - 3 for perf reasons.
// But if K is larger, then we need a larger workgroup // But if K is larger, then we need a larger workgroup
uint32_t max_pipeline = num_topk_pipelines - 3; uint32_t max_pipeline = num_topk_pipelines - 1;
uint32_t preferred_pipeline = std::max(num_topk_pipelines - 3, (uint32_t)log2f(float(k)) + 2);
max_pipeline = std::min(preferred_pipeline, max_pipeline);
uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1; uint32_t min_pipeline = (uint32_t)log2f(float(k)) + 1;
// require full subgroup // require full subgroup
min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2); min_pipeline = std::max(min_pipeline, ctx->device->subgroup_size_log2);

View File

@ -4,13 +4,6 @@
#include "types.glsl" #include "types.glsl"
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
#if defined(DATA_A_F32) #if defined(DATA_A_F32)
vec2 dequantize(uint ib, uint iqs, uint a_offset) { vec2 dequantize(uint ib, uint iqs, uint a_offset) {
return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]);

View File

@ -22,6 +22,13 @@ layout (push_constant) uniform parameter
#if !RMS_NORM_ROPE_FUSION #if !RMS_NORM_ROPE_FUSION
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#endif #endif

View File

@ -18,6 +18,13 @@ layout (push_constant) uniform parameter
} p; } p;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
uint get_idx() { uint get_idx() {

View File

@ -3,6 +3,7 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl" #include "mul_mat_vec_base.glsl"
#include "dequant_funcs.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

View File

@ -13,8 +13,6 @@
#include "mul_mat_vec_iface.glsl" #include "mul_mat_vec_iface.glsl"
#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter layout (push_constant) uniform parameter
{ {
uint ncols; uint ncols;

View File

@ -5,13 +5,15 @@
#define MAT_VEC_FUSION_FLAGS_SCALE0 0x4 #define MAT_VEC_FUSION_FLAGS_SCALE0 0x4
#define MAT_VEC_FUSION_FLAGS_SCALE1 0x8 #define MAT_VEC_FUSION_FLAGS_SCALE1 0x8
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#if defined(A_TYPE_VEC4) #if defined(A_TYPE_VEC4)
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
#endif #endif
#else #if defined(A_TYPE_PACKED16)
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];};
#endif
#if defined(A_TYPE_PACKED32)
layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];};
#endif #endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; layout (binding = 1) readonly buffer B {B_TYPE data_b[];};

View File

@ -10,60 +10,56 @@
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
#define K_PER_ITER 8 #define K_PER_ITER 8
#elif defined(DATA_A_QUANT_K)
#include "mul_mmq_funcs.glsl" #define K_PER_ITER 16
#else
#error unimplemented
#endif
uint a_offset, b_offset, d_offset; uint a_offset, b_offset, d_offset;
int32_t cache_b_qs[2]; int32_t cache_b_qs[K_PER_ITER / 4];
vec2 cache_b_ds; vec2 cache_b_ds;
#include "mul_mat_vecq_funcs.glsl"
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; const uint col = i*BLOCK_SIZE + tid*K_PER_ITER;
// Preload data_b block // Preload data_b block
const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset;
const uint b_qs_idx = tid % 4; const uint b_qs_idx = tid % (32 / K_PER_ITER);
const uint b_block_idx_outer = b_block_idx / 4; const uint b_block_idx_outer = b_block_idx / 4;
const uint b_block_idx_inner = b_block_idx % 4; const uint b_block_idx_inner = b_block_idx % 4;
cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]);
#if QUANT_R == 2 #if QUANT_R == 2
// Assumes K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4];
#else #else
#if K_PER_ITER == 8
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1];
#elif K_PER_ITER == 16
cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 ];
cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 1];
cache_b_qs[2] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 2];
cache_b_qs[3] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 4 + 3];
#else
#error unimplemented
#endif
#endif #endif
uint ibi = first_row*p.ncols; uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) { [[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; const uint a_block_idx = (ibi + col)/QUANT_K_Q8_1 + a_offset;
ibi += p.ncols; ibi += p.ncols;
int32_t q_sum = 0; temp[j][n] += mmvq_dot_product(a_block_idx, b_qs_idx);
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(a_block_idx, b_qs_idx * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
#if QUANT_AUXF == 1
temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4);
#else
temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4);
#endif
} }
} }
} }
@ -72,7 +68,7 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x; const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset); get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K; a_offset /= QUANT_K_Q8_1;
b_offset /= QUANT_K_Q8_1; b_offset /= QUANT_K_Q8_1;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
@ -102,14 +98,6 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
unroll_count = 2; unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1); unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) { while (i < unrolled_iters) {
// Manually partially unroll the loop // Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) { [[unroll]] for (uint k = 0; k < unroll_count; ++k) {
@ -128,6 +116,10 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
void main() { void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows // do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) { if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS); compute_outputs(first_row, NUM_ROWS);

View File

@ -0,0 +1,379 @@
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
#include "types.glsl"
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
}
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm);
}
#endif
#if defined(DATA_A_MXFP4)
FLOAT_TYPE get_dm(uint ib) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e));
}
#endif
#if defined(DATA_A_Q2_K)
FLOAT_TYPE_VEC2 get_dm(uint ib) {
const uint ib_k = ib / 8;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm);
}
#endif
// Each iqs value maps to a 32-bit integer
#if defined(DATA_A_Q4_0)
// 2-byte loads for Q4_0 blocks (18 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q4_1)
// 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q5_0)
// 2-byte loads for Q5_0 blocks (22 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#endif
#if defined(DATA_A_Q5_1)
// 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(float(q_sum) * da * dsb.x);
}
#endif
#if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
const u8vec4 i_a0 = unpack8( qs & 0x0F0F0F0F);
const u8vec4 i_a1 = unpack8((qs >> 4) & 0x0F0F0F0F);
return i32vec2(pack32(i8vec4(kvalues_mxfp4[i_a0.x], kvalues_mxfp4[i_a0.y], kvalues_mxfp4[i_a0.z], kvalues_mxfp4[i_a0.w])),
pack32(i8vec4(kvalues_mxfp4[i_a1.x], kvalues_mxfp4[i_a1.y], kvalues_mxfp4[i_a1.z], kvalues_mxfp4[i_a1.w])));
}
FLOAT_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return FLOAT_TYPE(da * dsb.x * float(q_sum) * 0.5);
}
#endif
#if defined(DATA_A_QUANT_LEGACY) || defined(DATA_A_MXFP4)
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
#if QUANT_R == 2
const i32vec2 data_a_qs = repack(ib_a, iqs);
q_sum += dotPacked4x8EXT(data_a_qs.x,
cache_b_qs[0]);
q_sum += dotPacked4x8EXT(data_a_qs.y,
cache_b_qs[1]);
#else
int32_t data_a_qs = repack(ib_a, iqs * 2);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[0]);
data_a_qs = repack(ib_a, iqs * 2 + 1);
q_sum += dotPacked4x8EXT(data_a_qs,
cache_b_qs[1]);
#endif
// 2 quants per call => divide sums by 8/2 = 4
return mul_q8_1(q_sum, get_dm(ib_a), cache_b_ds, 4);
}
#endif
#if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return i32vec4((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x03030303,
(data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t sum_d = 0;
int32_t sum_m = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const uint8_t scale = get_scale(ib_a, iqs * 4);
const vec2 dm = vec2(get_dm(ib_a));
const int32_t scale_m = int32_t(scale >> 4) * 0x01010101; // Duplicate 8-bit value across 32-bits.
sum_d += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[0]);
sum_d += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[1]);
sum_d += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[2]);
sum_d += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]) * (scale & 0xF);
sum_m += dotPacked4x8EXT(scale_m, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * (float(dm.x) * float(sum_d) - float(dm.y) * float(sum_m)));
}
#endif
#if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
const uint hm_shift = iqs_k / 8;
// bitwise OR to add 4 if hmask is set, subtract later
const i8vec2 vals00 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 ] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 ] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals01 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 1] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 1] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals10 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 2] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 2] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals11 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 3] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 3] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals20 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 4] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 4] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals21 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 5] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 5] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals30 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 6] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 6] >> hm_shift) & uint16_t(0x0101)) << 2));
const i8vec2 vals31 = unpack8(int16_t((data_a_packed16[ib_k].qs[qs_idx * 2 + 7] >> qs_shift) & uint16_t(0x0303))) |
unpack8(int16_t(((data_a_packed16[ib_k].hmask[iqs * 2 + 7] >> hm_shift) & uint16_t(0x0101)) << 2));
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y) - int8_t(4)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y) - int8_t(4)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y) - int8_t(4)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y) - int8_t(4)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 4;
const int8_t scale = int8_t(((data_a[ib_k].scales[is % 8 ] >> (4 * (is / 8))) & 0x0F0F) |
(((data_a[ib_k].scales[8 + (is % 4)] >> (2 * (is / 4))) & 0x0303) << 4));
return float(data_a[ib_k].d) * float(scale - 32);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * d_scale * float(q_sum));
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 16) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 16) / 8) * 4;
#if defined(DATA_A_Q4_K)
const uint32_t vals0 = (data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals1 = (data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals2 = (data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F;
const uint32_t vals3 = (data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F;
return i32vec4(vals0, vals1, vals2, vals3);
#else // defined(DATA_A_Q5_K)
const uint qh_idx = iqs;
const uint qh_shift = iqs_k / 8;
return i32vec4(((data_a_packed32[ib_k].qs[qs_idx ] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx ] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 1] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 1] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 2] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 2] >> qh_shift) & 0x01010101) << 4),
((data_a_packed32[ib_k].qs[qs_idx + 3] >> qs_shift) & 0x0F0F0F0F) |
(((data_a_packed32[ib_k].qh[qh_idx + 3] >> qh_shift) & 0x01010101) << 4));
#endif
}
vec2 get_dm_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint is = iqs_k / 8;
u8vec2 scale_dm;
if (is < 4) {
scale_dm = u8vec2(data_a[ib_k].scales[is] & 0x3F, data_a[ib_k].scales[is + 4] & 0x3F);
} else {
scale_dm = u8vec2((data_a[ib_k].scales[is+4] & 0xF) | ((data_a[ib_k].scales[is-4] & 0xC0) >> 2),
(data_a[ib_k].scales[is+4] >> 4) | ((data_a[ib_k].scales[is ] & 0xC0) >> 2));
}
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm) * FLOAT_TYPE_VEC2(scale_dm);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const vec2 dm_scale = get_dm_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(dm_scale.x) * float(q_sum) - float(dm_scale.y) * float(cache_b_ds.y / 2));
}
#endif
#if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes)
i32vec4 repack4(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint ql_idx = (iqs_k / 32) * 16 + iqs_k % 16;
const uint ql_shift = ((iqs_k % 32) / 16) * 4;
const uint qh_idx = (iqs_k / 32) * 8 + iqs;
const uint qh_shift = ((iqs_k % 32) / 8) * 2;
const i8vec2 vals00 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 ] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 ] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals01 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 1] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 1] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals10 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 2] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 2] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals11 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 3] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 3] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals20 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 4] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 4] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals21 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 5] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 5] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals30 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 6] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 6] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
const i8vec2 vals31 = (unpack8(int16_t((data_a_packed16[ib_k].ql[ql_idx * 2 + 7] >> ql_shift) & uint16_t(0x0F0F))) |
unpack8(int16_t(((data_a_packed16[ib_k].qh[qh_idx * 2 + 7] >> qh_shift) & uint16_t(0x0303)) << 4))) - int8_t(32);
return i32vec4(pack32(i8vec4(vals00.x, vals00.y, vals01.x, vals01.y)),
pack32(i8vec4(vals10.x, vals10.y, vals11.x, vals11.y)),
pack32(i8vec4(vals20.x, vals20.y, vals21.x, vals21.y)),
pack32(i8vec4(vals30.x, vals30.y, vals31.x, vals31.y)));
}
float get_d_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return float(data_a[ib_k].d) * float(data_a[ib_k].scales[iqs_k / 4]);
}
FLOAT_TYPE mmvq_dot_product(const uint ib_a, const uint iqs) {
int32_t q_sum = 0;
const i32vec4 qs_a = repack4(ib_a, iqs * 4);
const float d_scale = get_d_scale(ib_a, iqs * 4);
q_sum += dotPacked4x8EXT(qs_a.x, cache_b_qs[0]);
q_sum += dotPacked4x8EXT(qs_a.y, cache_b_qs[1]);
q_sum += dotPacked4x8EXT(qs_a.z, cache_b_qs[2]);
q_sum += dotPacked4x8EXT(qs_a.w, cache_b_qs[3]);
return FLOAT_TYPE(float(cache_b_ds.x) * float(d_scale) * float(q_sum));
}
#endif

View File

@ -78,8 +78,6 @@ layout (constant_id = 10) const uint WARP = 32;
#define BK 32 #define BK 32
#define MMQ_SHMEM
#include "mul_mmq_shmem_types.glsl" #include "mul_mmq_shmem_types.glsl"
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID

View File

@ -9,31 +9,6 @@
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1) #if defined(DATA_A_Q4_0) || defined(DATA_A_Q4_1)
// 2-byte loads for Q4_0 blocks (18 bytes) // 2-byte loads for Q4_0 blocks (18 bytes)
// 4-byte loads for Q4_1 blocks (20 bytes) // 4-byte loads for Q4_1 blocks (20 bytes)
i32vec2 repack(uint ib, uint iqs) {
#ifdef DATA_A_Q4_0
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#else // DATA_A_Q4_1
const uint32_t vui = data_a_packed32[ib].qs[iqs];
return i32vec2( vui & 0x0F0F0F0F,
(vui >> 4) & 0x0F0F0F0F);
#endif
}
#ifdef DATA_A_Q4_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q4_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q4_0 #ifdef DATA_A_Q4_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@ -73,42 +48,17 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a.y, qs_b1); q_sum += dotPacked4x8EXT(qs_a.y, qs_b1);
} }
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); #ifdef DATA_A_Q4_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 8.0 * float(cache_b.ds.y)));
#else // DATA_A_Q4_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
} }
#endif // MMQ_SHMEM #endif
#elif defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1) #if defined(DATA_A_Q5_0) || defined(DATA_A_Q5_1)
// 2-byte loads for Q5_0 blocks (22 bytes) // 2-byte loads for Q5_0 blocks (22 bytes)
// 4-byte loads for Q5_1 blocks (24 bytes) // 4-byte loads for Q5_1 blocks (24 bytes)
i32vec2 repack(uint ib, uint iqs) {
const u16vec2 quants = u16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]);
const uint32_t vui = pack32(quants);
#ifdef DATA_A_Q5_0
const int32_t qh = int32_t((uint32_t(data_a_packed16[ib].qh[1]) << 16 | data_a_packed16[ib].qh[0]) >> (4 * iqs));
#else // DATA_A_Q5_1
const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs));
#endif
const int32_t v0 = int32_t(vui & 0x0F0F0F0F)
| ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28)
const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F)
| (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28)
return i32vec2(v0, v1);
}
#ifdef DATA_A_Q5_0
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y));
}
#else // DATA_A_Q5_1
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor);
}
#endif
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
#ifdef DATA_A_Q5_0 #ifdef DATA_A_Q5_0
buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2], buf_a[buf_ib].qs[iqs] = pack32(u16vec2(data_a_packed16[ib].qs[iqs * 2],
@ -154,23 +104,16 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a1, qs_b1); q_sum += dotPacked4x8EXT(qs_a1, qs_b1);
} }
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); #ifdef DATA_A_Q5_0
return ACC_TYPE(float(cache_a[ib_a].dm) * (float(q_sum) * float(cache_b.ds.x) - 16.0 * float(cache_b.ds.y)));
#else // DATA_A_Q5_1
return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm.x) * float(cache_b.ds.x) + float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
#endif
} }
#endif // MMQ_SHMEM
#endif #endif
#if defined(DATA_A_Q8_0) #if defined(DATA_A_Q8_0)
// 2-byte loads for Q8_0 blocks (34 bytes) // 2-byte loads for Q8_0 blocks (34 bytes)
int32_t repack(uint ib, uint iqs) {
return pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2 ],
data_a_packed16[ib].qs[iqs * 2 + 1]));
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(float(q_sum) * da * dsb.x);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2], buf_a[buf_ib].qs[iqs] = pack32(i16vec2(data_a_packed16[ib].qs[iqs * 2],
data_a_packed16[ib].qs[iqs * 2 + 1])); data_a_packed16[ib].qs[iqs * 2 + 1]));
@ -197,28 +140,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, qs_b); q_sum += dotPacked4x8EXT(qs_a, qs_b);
} }
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); return ACC_TYPE(float(q_sum) * float(cache_a[ib_a].dm) * float(cache_b.ds.x));
} }
#endif // MMQ_SHMEM
#endif #endif
#if defined(DATA_A_MXFP4) #if defined(DATA_A_MXFP4)
// 1-byte loads for mxfp4 blocks (17 bytes) // 1-byte loads for mxfp4 blocks (17 bytes)
i32vec2 repack(uint ib, uint iqs) {
const uint32_t quants = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1],
data_a[ib].qs[iqs * 4 + 2],
data_a[ib].qs[iqs * 4 + 3]));
return i32vec2( quants & 0x0F0F0F0F,
(quants >> 4) & 0x0F0F0F0F);
}
ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(da * dsb.x * float(q_sum));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ], const uint32_t qs = pack32(u8vec4(data_a[ib].qs[iqs * 4 ],
data_a[ib].qs[iqs * 4 + 1], data_a[ib].qs[iqs * 4 + 1],
@ -252,37 +179,14 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
} }
return mul_q8_1(q_sum, cache_a[ib_a].d, cache_b.ds, 1); return ACC_TYPE(float(cache_a[ib_a].d) * float(cache_b.ds.x) * float(q_sum));
} }
#endif // MMQ_SHMEM
#endif #endif
// For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide // For k-quants, ib and iqs still assume 32-wide blocks, but k-quants are 256-wide
// iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants // iqs still refers to a 32-bit integer, meaning 0..7 for 32-wide quants
#if defined(DATA_A_Q2_K) #if defined(DATA_A_Q2_K)
// 4-byte loads for Q2_K blocks (84 bytes) // 4-byte loads for Q2_K blocks (84 bytes)
int32_t repack(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
const uint qs_idx = (iqs_k / 32) * 8 + (iqs_k % 8);
const uint qs_shift = ((iqs_k % 32) / 8) * 2;
return int32_t((data_a_packed32[ib_k].qs[qs_idx] >> qs_shift) & 0x03030303);
}
uint8_t get_scale(uint ib, uint iqs) {
const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs;
return data_a[ib_k].scales[iqs_k / 4];
}
ACC_TYPE mul_q8_1(const int32_t sum_d, const int32_t sum_m, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * (dma.x * float(sum_d) - dma.y * float(sum_m)));
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8; const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@ -326,14 +230,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]); sum_m += dotPacked4x8EXT(scale_m, cache_b.qs[iqs]);
} }
return mul_q8_1(sum_d, sum_m, cache_a[ib_a].dm, cache_b.ds, 1); return ACC_TYPE(float(cache_b.ds.x) * (float(cache_a[ib_a].dm.x) * float(sum_d) - float(cache_a[ib_a].dm.y) * float(sum_m)));
} }
#endif // MMQ_SHMEM
#endif #endif
#if defined(DATA_A_Q3_K) #if defined(DATA_A_Q3_K)
// 2-byte loads for Q3_K blocks (110 bytes) // 2-byte loads for Q3_K blocks (110 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8; const uint ib_k = ib / 8;
const uint hm_idx = iqs * QUANT_R_MMQ; const uint hm_idx = iqs * QUANT_R_MMQ;
@ -394,18 +296,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
} }
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result); return ACC_TYPE(float(cache_b.ds.x) * result);
} }
#endif // MMQ_SHMEM
#endif #endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) #if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes) // 4-byte loads for Q4_K blocks (144 bytes) and Q5_K blocks (176 bytes)
ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) {
return ACC_TYPE(dsb.x * dma.x * float(q_sum) - dma.y * dsb.y);
}
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8; const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ; const uint iqs_k = (ib % 8) * 8 + iqs * QUANT_R_MMQ;
@ -427,7 +323,6 @@ void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
(((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4)); (((data_a_packed32[ib_k].qh[qh_idx] >> qh_shift) & 0x01010101) << 4));
#endif #endif
if (iqs == 0) { if (iqs == 0) {
// Scale index // Scale index
const uint is = iqs_k / 8; const uint is = iqs_k / 8;
@ -464,49 +359,12 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]); q_sum += dotPacked4x8EXT(qs_a, cache_b.qs[iqs]);
} }
return mul_q8_1(q_sum, cache_a[ib_a].dm, cache_b.ds, 1); return ACC_TYPE(float(cache_b.ds.x) * float(cache_a[ib_a].dm.x) * float(q_sum) - float(cache_a[ib_a].dm.y) * float(cache_b.ds.y));
}
#endif // MMQ_SHMEM
#endif
#ifdef MMQ_SHMEM
void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
if (is_in_bounds) {
const uint ib_outer = ib / 4;
const uint ib_inner = ib % 4;
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
}
const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
buf_b[buf_ib].qs[iqs * 4 ] = values.x;
buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
} }
#endif #endif
#if defined(DATA_A_Q6_K) #if defined(DATA_A_Q6_K)
// 2-byte loads for Q6_K blocks (210 bytes) // 2-byte loads for Q6_K blocks (210 bytes)
#ifdef MMQ_SHMEM
void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) { void block_a_to_shmem(const uint buf_ib, const uint ib, const uint iqs) {
const uint ib_k = ib / 8; const uint ib_k = ib / 8;
const uint iqs_k = (ib % 8) * 8 + iqs; const uint iqs_k = (ib % 8) * 8 + iqs;
@ -558,32 +416,39 @@ ACC_TYPE mmq_dot_product(const uint ib_a) {
} }
result += float(cache_a[ib_a].d_scales[1]) * float(q_sum); result += float(cache_a[ib_a].d_scales[1]) * float(q_sum);
return ACC_TYPE(cache_b.ds.x * result); return ACC_TYPE(float(cache_b.ds.x) * result);
}
#endif // MMQ_SHMEM
#endif
#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL)
FLOAT_TYPE get_d(uint ib) {
return FLOAT_TYPE(data_a[ib].d);
} }
#endif #endif
#if defined(DATA_A_MXFP4) void block_b_to_shmem(const uint buf_ib, const uint ib, const uint iqs, const bool is_in_bounds) {
FLOAT_TYPE get_d(uint ib) { if (is_in_bounds) {
return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); const uint ib_outer = ib / 4;
} const uint ib_inner = ib % 4;
#endif
#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) if (iqs == 0) {
FLOAT_TYPE_VEC2 get_dm(uint ib) { buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]);
return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); }
}
#endif
#if defined(DATA_A_Q2_K) const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs];
FLOAT_TYPE_VEC2 get_dm(uint ib) { buf_b[buf_ib].qs[iqs * 4 ] = values.x;
const uint ib_k = ib / 8; buf_b[buf_ib].qs[iqs * 4 + 1] = values.y;
return FLOAT_TYPE_VEC2(data_a_packed32[ib_k].dm); buf_b[buf_ib].qs[iqs * 4 + 2] = values.z;
buf_b[buf_ib].qs[iqs * 4 + 3] = values.w;
} else {
if (iqs == 0) {
buf_b[buf_ib].ds = FLOAT_TYPE_VEC2(0.0f);
}
buf_b[buf_ib].qs[iqs * 4 ] = 0;
buf_b[buf_ib].qs[iqs * 4 + 1] = 0;
buf_b[buf_ib].qs[iqs * 4 + 2] = 0;
buf_b[buf_ib].qs[iqs * 4 + 3] = 0;
}
}
void block_b_to_registers(const uint ib) {
cache_b.ds = buf_b[ib].ds;
[[unroll]] for (uint iqs = 0; iqs < BK / 4; iqs++) {
cache_b.qs[iqs] = buf_b[ib].qs[iqs];
}
} }
#endif

View File

@ -679,14 +679,20 @@ void process_shaders() {
string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
// mul mat vec with integer dot product // mul mat vec with integer dot product
#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) #if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
if (is_legacy_quant(tname)) { if (is_legacy_quant(tname) || tname == "mxfp4" || is_k_quant(tname)) {
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}));
string_to_spv("mul_mat_vec_id_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}}));
} }
#endif #endif
@ -1100,7 +1106,7 @@ void write_output_files() {
for (const std::string& btype : btypes) { for (const std::string& btype : btypes) {
for (const auto& tname : type_names) { for (const auto& tname : type_names) {
if (btype == "q8_1" && !is_legacy_quant(tname)) { if (btype == "q8_1" && !is_legacy_quant(tname) && tname != "mxfp4" && !is_k_quant(tname)) {
continue; continue;
} }
hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n";
@ -1109,6 +1115,16 @@ void write_output_files() {
src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
} }
if (btype == "f16") {
continue;
}
hdr << "extern const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3];\n";
hdr << "extern const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3];\n";
if (basename(input_filepath) == "mul_mat_vec.comp") {
src << "const void * arr_dmmv_id_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n";
src << "const uint64_t arr_dmmv_id_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_id_" << tname << "_" << btype << "_f32_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_id_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n";
}
} }
} }

View File

@ -371,10 +371,13 @@ class GGUFWriter:
def add_tensor( def add_tensor(
self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None, self, name: str, tensor: np.ndarray[Any, Any], raw_shape: Sequence[int] | None = None,
raw_dtype: GGMLQuantizationType | None = None, raw_dtype: GGMLQuantizationType | None = None, tensor_endianess: GGUFEndian | None = None
) -> None: ) -> None:
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ # if tensor endianness is not passed, assume it's native to system
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it # Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False) tensor = tensor.byteswap(inplace=False)
if self.use_temp_file and self.temp_file is None: if self.use_temp_file and self.temp_file is None:
@ -397,13 +400,16 @@ class GGUFWriter:
if pad != 0: if pad != 0:
fp.write(bytes([0] * pad)) fp.write(bytes([0] * pad))
def write_tensor_data(self, tensor: np.ndarray[Any, Any]) -> None: def write_tensor_data(self, tensor: np.ndarray[Any, Any], tensor_endianess: GGUFEndian | None = None) -> None:
if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS: if self.state is not WriterState.TI_DATA and self.state is not WriterState.WEIGHTS:
raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}') raise ValueError(f'Expected output file to contain tensor info or weights, got {self.state}')
assert self.fout is not None assert self.fout is not None
if (self.endianess == GGUFEndian.BIG and sys.byteorder != 'big') or \ # if tensor endianness is not passed, assume it's native to system
(self.endianess == GGUFEndian.LITTLE and sys.byteorder != 'little'): if tensor_endianess is None:
tensor_endianess = GGUFEndian.BIG if sys.byteorder == 'big' else GGUFEndian.LITTLE
if tensor_endianess != self.endianess:
# Don't byteswap inplace since lazy copies cannot handle it # Don't byteswap inplace since lazy copies cannot handle it
tensor = tensor.byteswap(inplace=False) tensor = tensor.byteswap(inplace=False)

View File

@ -1552,7 +1552,7 @@ class GGUFEditorWindow(QMainWindow):
# Add tensors (including data) # Add tensors (including data)
for tensor in self.reader.tensors: for tensor in self.reader.tensors:
writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type) writer.add_tensor(tensor.name, tensor.data, raw_shape=tensor.data.shape, raw_dtype=tensor.tensor_type, tensor_endianess=self.reader.endianess)
# Write header and metadata # Write header and metadata
writer.open_output_file(Path(file_path)) writer.open_output_file(Path(file_path))

View File

@ -94,7 +94,7 @@ def copy_with_new_metadata(reader: gguf.GGUFReader, writer: gguf.GGUFWriter, new
writer.write_ti_data_to_file() writer.write_ti_data_to_file()
for tensor in reader.tensors: for tensor in reader.tensors:
writer.write_tensor_data(tensor.data) writer.write_tensor_data(tensor.data, tensor_endianess=reader.endianess)
bar.update(tensor.n_bytes) bar.update(tensor.n_bytes)
writer.close() writer.close()

View File

@ -300,7 +300,7 @@ llama_context::llama_context(
cross.v_embd.clear(); cross.v_embd.clear();
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
// avoid reserving graphs with zero outputs - assume one output per sequence // avoid reserving graphs with zero outputs - assume one output per sequence
@ -543,7 +543,7 @@ bool llama_context::memory_update(bool optimize) {
throw std::runtime_error("failed to initialize memory context"); throw std::runtime_error("failed to initialize memory context");
} }
const uint32_t n_seqs = cparams.kv_unified ? 1 : cparams.n_seq_max; const uint32_t n_seqs = cparams.n_seq_max;
const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch); const uint32_t n_tokens = std::min(cparams.n_ctx, cparams.n_ubatch);
auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get()); auto * gf = graph_reserve(n_tokens, n_seqs, n_tokens, mctx.get());

View File

@ -196,7 +196,7 @@ if (NOT WIN32)
llama_build_and_test(test-arg-parser.cpp) llama_build_and_test(test-arg-parser.cpp)
endif() endif()
if (NOT LLAMA_SANITIZE_ADDRESS) if (NOT LLAMA_SANITIZE_ADDRESS AND NOT GGML_SCHED_NO_REALLOC)
# TODO: repair known memory leaks # TODO: repair known memory leaks
llama_build_and_test(test-opt.cpp) llama_build_and_test(test-opt.cpp)
endif() endif()

View File

@ -1446,14 +1446,14 @@ struct test_case {
const uint64_t target_flops_cpu = 8ULL * GFLOP; const uint64_t target_flops_cpu = 8ULL * GFLOP;
const uint64_t target_flops_gpu = 100ULL * GFLOP; const uint64_t target_flops_gpu = 100ULL * GFLOP;
uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu; uint64_t target_flops = is_cpu ? target_flops_cpu : target_flops_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1; n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_flops / op_flops(out)) + 1;
} else { } else {
// based on memory size // based on memory size
const size_t GB = 1ULL << 30; const size_t GB = 1ULL << 30;
const size_t target_size_cpu = 8 * GB; const size_t target_size_cpu = 8 * GB;
const size_t target_size_gpu = 32 * GB; const size_t target_size_gpu = 32 * GB;
size_t target_size = is_cpu ? target_size_cpu : target_size_gpu; size_t target_size = is_cpu ? target_size_cpu : target_size_gpu;
n_runs = std::min<int>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1; n_runs = (int)std::min<int64_t>(ggml_graph_size(gf) - ggml_graph_n_nodes(gf), target_size / op_size(out)) + 1;
} }
// duplicate the op // duplicate the op
@ -8043,7 +8043,9 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_perf() {
} }
test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1})); test_cases.emplace_back(new test_argsort(GGML_TYPE_F32, {65000, 16, 1, 1}));
for (auto k : {1, 10, 40}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {2, 1, 1, 1}, 1));
for (auto k : {1, 10, 40, 400}) {
for (auto nrows : {1, 16}) { for (auto nrows : {1, 16}) {
for (auto cols : {k, 1000, 65000, 200000}) { for (auto cols : {k, 1000, 65000, 200000}) {
test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k)); test_cases.emplace_back(new test_top_k(GGML_TYPE_F32, {cols, nrows, 1, 1}, k));

View File

@ -1339,6 +1339,32 @@ static void test_all(const std::string & lang, std::function<void(const TestCase
space ::= | " " | "\n"{1,2} [ \t]{0,20} space ::= | " " | "\n"{1,2} [ \t]{0,20}
)""" )"""
}); });
test({
SUCCESS,
"literal string with escapes",
R"""({
"properties": {
"code": {
"const": " \r \n \" \\ ",
"description": "Generated code",
"title": "Code",
"type": "string"
}
},
"required": [
"code"
],
"title": "DecoderResponse",
"type": "object"
})""",
R"""(
code ::= "\" \\r \\n \\\" \\\\ \"" space
code-kv ::= "\"code\"" space ":" space code
root ::= "{" space code-kv "}" space
space ::= | " " | "\n"{1,2} [ \t]{0,20}
)"""
});
} }
int main() { int main() {

View File

@ -23,6 +23,8 @@ set(TARGET_SRCS
server-queue.h server-queue.h
server-common.cpp server-common.cpp
server-common.h server-common.h
server-context.cpp
server-context.h
) )
set(PUBLIC_ASSETS set(PUBLIC_ASSETS
index.html.gz index.html.gz

View File

@ -257,9 +257,9 @@ const STRING_FORMAT_RULES = {
const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES}; const RESERVED_NAMES = {'root': true, ...PRIMITIVE_RULES, ...STRING_FORMAT_RULES};
const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g; const INVALID_RULE_CHARS_RE = /[^\dA-Za-z-]+/g;
const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"]/g; const GRAMMAR_LITERAL_ESCAPE_RE = /[\n\r"\\]/g;
const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g; const GRAMMAR_RANGE_LITERAL_ESCAPE_RE = /[\n\r"\]\-\\]/g;
const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]' }; const GRAMMAR_LITERAL_ESCAPES = { '\r': '\\r', '\n': '\\n', '"': '\\"', '-': '\\-', ']': '\\]', '\\': '\\\\' };
const NON_LITERAL_SET = new Set('|.()[]{}*+?'); const NON_LITERAL_SET = new Set('|.()[]{}*+?');
const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?'); const ESCAPED_IN_REGEXPS_BUT_NOT_IN_LITERALS = new Set('^$.[]()|{}*+?');

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,83 @@
#include "server-http.h"
#include "server-task.h"
#include "server-queue.h"
#include <nlohmann/json_fwd.hpp>
#include <cstddef>
#include <memory>
struct server_context_impl; // private implementation
struct server_context {
std::unique_ptr<server_context_impl> impl;
server_context();
~server_context();
// initialize slots and server-related data
void init();
// load the model and initialize llama_context
// returns true on success
bool load_model(const common_params & params);
// this function will block main thread until termination
void start_loop();
// terminate main loop (will unblock start_loop)
void terminate();
// get the underlaying llama_context
llama_context * get_llama_context() const;
// get the underlaying queue_tasks and queue_results
// used by CLI application
std::pair<server_queue &, server_response &> get_queues();
};
// forward declarations
struct server_res_generator;
struct server_routes {
server_routes(const common_params & params, server_context & ctx_server, std::function<bool()> is_ready = []() { return true; })
: params(params), ctx_server(*ctx_server.impl), is_ready(is_ready) {
init_routes();
}
void init_routes();
// handlers using lambda function, so that they can capture `this` without `std::bind`
server_http_context::handler_t get_health;
server_http_context::handler_t get_metrics;
server_http_context::handler_t get_slots;
server_http_context::handler_t post_slots;
server_http_context::handler_t get_props;
server_http_context::handler_t post_props;
server_http_context::handler_t get_api_show;
server_http_context::handler_t post_infill;
server_http_context::handler_t post_completions;
server_http_context::handler_t post_completions_oai;
server_http_context::handler_t post_chat_completions;
server_http_context::handler_t post_anthropic_messages;
server_http_context::handler_t post_anthropic_count_tokens;
server_http_context::handler_t post_apply_template;
server_http_context::handler_t get_models;
server_http_context::handler_t post_tokenize;
server_http_context::handler_t post_detokenize;
server_http_context::handler_t post_embeddings;
server_http_context::handler_t post_embeddings_oai;
server_http_context::handler_t post_rerank;
server_http_context::handler_t get_lora_adapters;
server_http_context::handler_t post_lora_adapters;
private:
// TODO: move these outside of server_routes?
std::unique_ptr<server_res_generator> handle_slots_save(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_restore(const server_http_req & req, int id_slot);
std::unique_ptr<server_res_generator> handle_slots_erase(const server_http_req &, int id_slot);
std::unique_ptr<server_res_generator> handle_embeddings_impl(const server_http_req & req, task_response_type res_type);
const common_params & params;
server_context_impl & ctx_server;
std::function<bool()> is_ready;
};

View File

@ -199,7 +199,7 @@ server_task_result_ptr server_response::recv(const std::unordered_set<int> & id_
std::unique_lock<std::mutex> lock(mutex_results); std::unique_lock<std::mutex> lock(mutex_results);
condition_results.wait(lock, [&]{ condition_results.wait(lock, [&]{
if (!running) { if (!running) {
RES_DBG("%s : queue result stop\n", __func__); RES_DBG("%s : queue result stop\n", "recv");
std::terminate(); // we cannot return here since the caller is HTTP code std::terminate(); // we cannot return here since the caller is HTTP code
} }
return !queue_results.empty(); return !queue_results.empty();
@ -266,3 +266,86 @@ void server_response::terminate() {
running = false; running = false;
condition_results.notify_all(); condition_results.notify_all();
} }
//
// server_response_reader
//
void server_response_reader::post_tasks(std::vector<server_task> && tasks) {
id_tasks = server_task::get_list_id(tasks);
queue_results.add_waiting_tasks(tasks);
queue_tasks.post(std::move(tasks));
}
bool server_response_reader::has_next() const {
return !cancelled && received_count < id_tasks.size();
}
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr server_response_reader::next(const std::function<bool()> & should_stop) {
while (true) {
server_task_result_ptr result = queue_results.recv_with_timeout(id_tasks, polling_interval_seconds);
if (result == nullptr) {
// timeout, check stop condition
if (should_stop()) {
SRV_DBG("%s", "stopping wait for next result due to should_stop condition\n");
return nullptr;
}
} else {
if (result->is_error()) {
stop(); // cancel remaining tasks
SRV_DBG("%s", "received error result, stopping further processing\n");
return result;
}
if (result->is_stop()) {
received_count++;
}
return result;
}
}
// should not reach here
}
server_response_reader::batch_response server_response_reader::wait_for_all(const std::function<bool()> & should_stop) {
batch_response batch_res;
batch_res.results.resize(id_tasks.size());
while (has_next()) {
auto res = next(should_stop);
if (res == nullptr) {
batch_res.is_terminated = true;
return batch_res;
}
if (res->is_error()) {
batch_res.error = std::move(res);
return batch_res;
}
const size_t idx = res->get_index();
GGML_ASSERT(idx < batch_res.results.size() && "index out of range");
GGML_ASSERT(batch_res.results[idx] == nullptr && "duplicate result received");
batch_res.results[idx] = std::move(res);
}
return batch_res;
}
void server_response_reader::stop() {
queue_results.remove_waiting_task_ids(id_tasks);
if (has_next() && !cancelled) {
// if tasks is not finished yet, cancel them
cancelled = true;
std::vector<server_task> cancel_tasks;
cancel_tasks.reserve(id_tasks.size());
for (const auto & id_task : id_tasks) {
SRV_WRN("cancel task, id_task = %d\n", id_task);
server_task task(SERVER_TASK_TYPE_CANCEL);
task.id_target = id_task;
queue_results.remove_waiting_task_id(id_task);
cancel_tasks.push_back(std::move(task));
}
// push to beginning of the queue, so it has highest priority
queue_tasks.post(std::move(cancel_tasks), true);
} else {
SRV_DBG("%s", "all tasks already finished, no need to cancel\n");
}
}

View File

@ -108,3 +108,39 @@ public:
// terminate the waiting loop // terminate the waiting loop
void terminate(); void terminate();
}; };
// utility class to make working with server_queue and server_response easier
// it provides a generator-like API for server responses
// support pooling connection state and aggregating multiple results
struct server_response_reader {
std::unordered_set<int> id_tasks;
server_queue & queue_tasks;
server_response & queue_results;
size_t received_count = 0;
bool cancelled = false;
int polling_interval_seconds;
// should_stop function will be called each polling_interval_seconds
server_response_reader(std::pair<server_queue &, server_response &> server_queues, int polling_interval_seconds)
: queue_tasks(server_queues.first), queue_results(server_queues.second), polling_interval_seconds(polling_interval_seconds) {}
~server_response_reader() {
stop();
}
void post_tasks(std::vector<server_task> && tasks);
bool has_next() const;
// return nullptr if should_stop() is true before receiving a result
// note: if one error is received, it will stop further processing and return error result
server_task_result_ptr next(const std::function<bool()> & should_stop);
struct batch_response {
bool is_terminated = false; // if true, indicates that processing was stopped before all results were received
std::vector<server_task_result_ptr> results;
server_task_result_ptr error; // nullptr if no error
};
// aggregate multiple results
batch_response wait_for_all(const std::function<bool()> & should_stop);
void stop();
};

File diff suppressed because it is too large Load Diff